Queue does not take full advantage of multiprocessing
Enhancement
`Queue` has a `num_workers` attribute that uses multiprocessing to load samples faster. However, while the samples are loaded in parallel, the patches are extracted in the main process regardless of the `num_workers` value, which can greatly increase loading time.
To reproduce
```python
import timeit

import torch
import torchio

image_size = 120, 120, 120
patch_size = 64, 64, 64

# Ten subjects, each with a random 120³ image
dataset = torchio.ImagesDataset([
    torchio.Subject(img=torchio.Image(tensor=torch.rand(image_size)))
    for _ in range(10)
])
sampler = torchio.data.sampler.UniformSampler(patch_size)
queue = torchio.data.Queue(
    dataset,
    max_length=100,
    samples_per_volume=10,
    sampler=sampler,
    num_workers=10,
)
print(timeit.timeit(queue.fill, number=10))
```
This code prints 76.48914761... (seconds for the ten calls to `queue.fill`).
Generating patches in workers
With the change described below, the same code yields 10.45647265...
My change was simply to generate the patches inside the `collate_fn` of the `DataLoader`. It makes loading much faster, but it consumes more RAM (I guess because waiting workers now hold not only a sample but also its patches). Because of this trade-off, I think it would be a good idea to give the user the option to extract both samples and patches with multiprocessing, or only the samples. What do you think?
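To make that proposal concrete, here is a minimal self-contained sketch of such an option, independent of torchio; the flag name `patches_in_workers`, the toy dataset, and the toy sampler are all hypothetical stand-ins, not existing API:

```python
from itertools import islice
from typing import Iterator, List

import torch
from torch.utils.data import DataLoader, Dataset


class ToyVolumes(Dataset):
    """Random volumes standing in for torchio subjects."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.rand(120, 120, 120)


def toy_sampler(volume: torch.Tensor) -> Iterator[torch.Tensor]:
    """Endlessly yield 64³ crops; a stand-in for UniformSampler."""
    while True:
        yield volume[:64, :64, :64].clone()


def make_subjects_loader(
        dataset: Dataset,
        samples_per_volume: int,
        num_workers: int,
        patches_in_workers: bool,  # hypothetical user-facing option
) -> DataLoader:
    if patches_in_workers:
        # Faster, more RAM: each worker returns a list of patches
        def collate_fn(batch: List[torch.Tensor]) -> List[torch.Tensor]:
            return list(islice(toy_sampler(batch[0]), samples_per_volume))
    else:
        # Slower, less RAM: workers return whole volumes and the
        # caller extracts the patches in the main process
        def collate_fn(batch: List[torch.Tensor]) -> torch.Tensor:
            return batch[0]
    return DataLoader(dataset, num_workers=num_workers, collate_fn=collate_fn)
```

With `patches_in_workers=True` this mirrors the change described below; with `False`, it mirrors the current behavior.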
Here is the change I made to the `Queue` class:
```python
# Methods of the Queue class; required imports shown at the top
import random
from itertools import islice
from typing import Iterator

from torch.utils.data import DataLoader


def fill(self) -> None:
    ...
    for _ in iterable:
        patches = self.get_next_subject_sample()  # CHANGE HERE
        self.patches_list.extend(patches)
    if self.shuffle_patches:
        random.shuffle(self.patches_list)


def get_next_subject_sample(self) -> list:
    # With the new collate_fn, this now returns a list of patches
    # A StopIteration exception is expected when the queue is empty
    try:
        subject_sample = next(self.subjects_iterable)
    except StopIteration as exception:
        self.print('Queue is empty:', exception)
        self.subjects_iterable = self.get_subjects_iterable()
        subject_sample = next(self.subjects_iterable)
    return subject_sample


def get_subjects_iterable(self) -> Iterator:
    def collate_fn(x):  # CHANGE HERE
        # Runs in the worker process, so the patches are extracted there
        iterable = self.sampler(x[0])
        return list(islice(iterable, self.samples_per_volume))
    # I need a DataLoader to handle parallelism
    # But this loader is always expected to yield single subject samples
    self.print(
        '\nCreating subjects loader with', self.num_workers, 'workers')
    subjects_loader = DataLoader(
        self.subjects_dataset,
        num_workers=self.num_workers,
        collate_fn=collate_fn,  # CHANGE HERE
        shuffle=self.shuffle_subjects,
    )
    return iter(subjects_loader)
```
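The reason moving the extraction into `collate_fn` parallelizes it is that, when `num_workers > 0`, PyTorch's `DataLoader` invokes `collate_fn` inside the worker processes rather than in the main process. A minimal demonstration, with a toy dataset whose name and sizes are purely illustrative:

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset


class Noise(Dataset):
    """Toy dataset standing in for the subjects dataset."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        return torch.rand(8)


def collate_fn(batch):
    # With num_workers > 0 this prints worker PIDs, not the main PID
    print('collate_fn running in process', os.getpid())
    return batch[0]


if __name__ == '__main__':
    print('Main process:', os.getpid())
    loader = DataLoader(Noise(), num_workers=2, collate_fn=collate_fn)
    for _ in loader:
        pass
```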
Top GitHub Comments
That’s impressive, thanks @GFabien. I’ve been doing some testing after understanding your code.
Extracting the patches used to be much faster, that’s why I was doing it in the main process. Now it seems very slow, I must have changed the behavior somewhere. I’ll see if I can get shorter times and if not I’ll implement what you suggest.
I run out of RAM very quickly using the script below with the changes in #201, so I’m going to close that and this for now. Feel free to reopen.