
Multiple Workers with PyTorch DataLoader


First off, JAX is a great library, love using it.

My issue is that my PyTorch DataLoader freezes whenever I use >0 workers. The dataset itself returns jnp arrays. How can I fix this? Here’s a minimal example of what I’m doing:

import os
import jax.numpy as jnp
import numpy as np
import torch
import torchvision


class SpecialMNIST(torch.utils.data.Dataset):
    def __init__(self, train, seed=0):
        super().__init__()
        self._data = torchvision.datasets.MNIST(os.getcwd(), train, download=True)
        self._data_len = len(self._data)

    def __getitem__(self, index):
        # Convert the PIL image to a jnp array before returning it.
        img, label = self._data[index]
        return jnp.asarray(np.asarray(img)), label

    def __len__(self):
        return self._data_len


def collate_fn(batch):
    # Recursively collate samples: stack jnp arrays, recurse into
    # tuples/lists, and fall back to jnp.asarray for scalars (e.g. labels).
    if isinstance(batch[0], jnp.ndarray):
        return jnp.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return type(batch[0])(collate_fn(samples) for samples in zip(*batch))
    else:
        return jnp.asarray(batch)

The following code works:

dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         collate_fn=collate_fn,
                                         num_workers=0)
next(iter(dataloader))

But the following code just hangs:

dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset, 
                                         collate_fn=collate_fn,
                                         num_workers=1)
next(iter(dataloader))

The problem goes away when I change the jnp arrays to np arrays in the dataset and collate function. However, in my actual use case I have some complex data augmentation that I would like to JIT-compile and run on the CPU. Is there any way to do this in JAX, or do I have to stick to normal numpy functions when data loading? Thanks!

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 3
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

5 reactions
apaszke commented, Jun 10, 2020

@n2cholas would adding

import torch.multiprocessing as multiprocessing
multiprocessing.set_start_method('spawn')

at the top of your script help? What you’re seeing might be due to the XLA state being inconsistent after forking off the data loading processes. It will make their start-up quite a bit more expensive, but it might at least be more correct. Also, please remember to wrap your code in if __name__ == '__main__' if you do that (see the multiprocessing docs).
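
A minimal sketch of how that suggestion could look, reusing SpecialMNIST and collate_fn from the question above:

import torch
import torch.multiprocessing as multiprocessing

if __name__ == '__main__':
    # 'spawn' starts each worker with a fresh interpreter instead of
    # fork(), so workers don't inherit the parent's XLA state. Keeping
    # the call inside the __main__ guard prevents re-imports in the
    # spawned workers from setting the start method a second time.
    multiprocessing.set_start_method('spawn')

    dataset = SpecialMNIST(train=False)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             collate_fn=collate_fn,
                                             num_workers=1)
    print(next(iter(dataloader))[0].shape)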

1 reaction
cgarciae commented, Jun 9, 2020

I think you can use np.asarray() to convert back to normal numpy after the JAX call; if I am not mistaken, this avoids creating a copy of the data.
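
As a rough sketch of that idea, with a hypothetical augment function standing in for the real augmentation (and note this would still need the spawn start method from the previous comment if the JAX work runs inside DataLoader workers):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def augment(img):
    # Hypothetical stand-in for the real, more complex augmentation.
    return img.astype(jnp.float32) / 255.0

def prepare(img):
    # Run the jitted augmentation in JAX, then convert the result back
    # to a plain NumPy array before it crosses the worker boundary.
    # On CPU, np.asarray on the result should avoid copying the buffer.
    return np.asarray(augment(jnp.asarray(np.asarray(img))))

With __getitem__ returning plain NumPy arrays like this, the collate function can stay in NumPy as well, and only the augmentation itself goes through JAX.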


