
Queue does not take full advantage of multiprocessing


Enhancement

Queue has a num_workers attribute to take advantage of multiprocessing and load samples faster. However, while the samples themselves are loaded in parallel, the patches are extracted in the main process whatever the num_workers value is, which can greatly increase loading time.

To reproduce

import torch
import torchio
import timeit

image_size = 120, 120, 120
patch_size = 64, 64, 64
dataset = torchio.ImagesDataset([torchio.Subject(
    img=torchio.Image(tensor=torch.rand(image_size))
) for i in range(10)])
sampler = torchio.data.sampler.UniformSampler(patch_size)
queue = torchio.data.Queue(dataset, max_length=100, samples_per_volume=10, sampler=sampler, num_workers=10)

print(timeit.timeit(queue.fill, number=10))

This code yields 76.48914761... (seconds, as measured by timeit).

Generating patches in workers

With the patches generated inside the workers, the same code now yields 10.45647265...

My change was simply to generate the patches inside the collate_fn of the DataLoader. It seems to make loading much faster, but also to consume more RAM (I guess because workers that are waiting now hold not only a sample but also its patches). Because of this compromise, I think it would be a good idea to give the user the option to generate both samples and patches with multiprocessing, or only samples. What do you think?

Here is the change I made to the Queue class:


    def fill(self) -> None:
        ...
        for _ in iterable:
            # Each item from the loader is now a list of patches for one subject
            patches = self.get_next_subject_sample()       # CHANGE HERE
            self.patches_list.extend(patches)
        if self.shuffle_patches:
            random.shuffle(self.patches_list)

    def get_next_subject_sample(self) -> dict:
        # A StopIteration exception is expected when the queue is empty
        try:
            subject_sample = next(self.subjects_iterable)
        except StopIteration as exception:
            self.print('Queue is empty:', exception)
            self.subjects_iterable = self.get_subjects_iterable()
            subject_sample = next(self.subjects_iterable)
        return subject_sample

    def get_subjects_iterable(self) -> Iterator:
        def collate_fn(x):                      # CHANGE HERE
            # The loader yields one subject at a time; extract
            # samples_per_volume patches from it here, inside the worker
            # (islice is itertools.islice)
            iterable = self.sampler(x[0])
            return list(islice(iterable, self.samples_per_volume))
        # I need a DataLoader to handle parallelism
        # But this loader is always expected to yield single subject samples
        self.print(
            '\nCreating subjects loader with', self.num_workers, 'workers')
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=collate_fn,          # CHANGE HERE
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)
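
As a minimal sketch of the option proposed above, the collate_fn could be selected with a hypothetical patches_in_workers flag; the flag name and this structure are assumptions, not part of the actual Queue API:

    def get_subjects_iterable(self) -> Iterator:
        # Hypothetical option: extract patches inside the workers (new behavior)
        # or keep yielding whole subject samples (previous behavior)
        if self.patches_in_workers:
            def collate_fn(x):
                iterable = self.sampler(x[0])
                return list(islice(iterable, self.samples_per_volume))
        else:
            def collate_fn(x):
                return x[0]
        subjects_loader = DataLoader(
            self.subjects_dataset,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
            shuffle=self.shuffle_subjects,
        )
        return iter(subjects_loader)

With such a flag, fill() would also need a matching branch: extend patches_list directly when the loader already yields patches, and keep sampling patches in the main process otherwise.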

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
fepegar commented, Jun 19, 2020

That’s impressive, thanks @GFabien. I’ve been doing some testing after understanding your code.

Extracting the patches used to be much faster; that’s why I was doing it in the main process. Now it seems very slow, so I must have changed the behavior somewhere. I’ll see if I can get shorter times, and if not I’ll implement what you suggest.

0 reactions
fepegar commented, Jun 21, 2020

I run out of RAM very quickly when using the script below with the changes in #201, so I’m going to close that and this issue for now. Feel free to reopen.

import time
import multiprocessing as mp

from tqdm import tqdm, trange

import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from torchio import ImagesDataset, Queue, DATA
from torchio.data.sampler import UniformSampler
from torchio.utils import create_dummy_dataset
from torchio.transforms import (
    ZNormalization,
    RandomNoise,
    RandomFlip,
    RandomAffine,
)

# Define training and patches sampling parameters
num_epochs = 4
patch_size = 128
queue_length = 100
samples_per_volume = 10
batch_size = 4

def model(batch, sleep_time=0.1):
    """Dummy function to simulate a forward pass through the network"""
    time.sleep(sleep_time)
    return batch

# Create a dummy dataset in the temporary directory, for this example
subjects_list = create_dummy_dataset(
    num_images=100,
    size_range=(193, 229),
    force=False,
    verbose=True,
)

# Each element of subjects_list is an instance of torchio.Subject:
# subject = Subject(
#     one_image=torchio.Image(path_to_one_image, torchio.INTENSITY),
#     another_image=torchio.Image(path_to_another_image, torchio.INTENSITY),
#     a_label=torchio.Image(path_to_a_label, torchio.LABEL),
# )

# Define transforms for data normalization and augmentation
transforms = (
    ZNormalization(),
    RandomNoise(std=(0, 0.25)),
    RandomAffine(scales=(0.9, 1.1), degrees=10),
    RandomFlip(axes=(0,)),
)
transform = Compose(transforms)
subjects_dataset = ImagesDataset(subjects_list, transform)

sampler = UniformSampler(patch_size)

# Run a benchmark for different numbers of workers
workers = range(mp.cpu_count() + 1)
for num_workers in workers:
    print('Number of workers:', num_workers)

    # Define the dataset as a queue of patches
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        sampler,
        num_workers=num_workers,
        verbose=True,
    )
    batch_loader = DataLoader(queue_dataset, batch_size=batch_size)

    start = time.time()
    epochs_progress = trange(num_epochs, leave=False)
    for epoch_index in epochs_progress:
        batches_progress = tqdm(batch_loader, leave=False)
        for batch in batches_progress:
            # The keys of batch have been defined in create_dummy_dataset()
            inputs = batch['one_modality'][DATA]
            targets = batch['segmentation'][DATA]
            logits = model(inputs)
    print('Time:', int(time.time() - start), 'seconds')
    print()
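
As a rough way to observe the RAM compromise discussed above, one could record the resident memory of the benchmark process and its DataLoader workers once per batch and keep the maximum for each num_workers value. This is only a sketch using psutil, which the original script does not import:

import os
import psutil

def total_rss_mb():
    """Resident memory of this process plus its children (the loader workers), in MiB."""
    main = psutil.Process(os.getpid())
    total = 0
    for proc in [main] + main.children(recursive=True):
        try:
            total += proc.memory_info().rss
        except psutil.NoSuchProcess:
            # A worker may exit between listing it and querying it
            pass
    return total / 1024 ** 2

# Inside the batch loop above, something like:
#     peak_rss = max(peak_rss, total_rss_mb())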