Multiple Workers with PyTorch Dataloader
First off, JAX is a great library, love using it.
My issue is that my PyTorch DataLoader hangs whenever I use `num_workers > 0`. The dataset itself works with `jnp.array`s. How can I fix this? Here's a minimal example of what I'm doing:
```python
import os

import jax.numpy as jnp
import numpy as np
import torch
import torchvision


class SpecialMNIST(torch.utils.data.Dataset):
    def __init__(self, train, seed=0):
        super().__init__()
        self._data = torchvision.datasets.MNIST(os.getcwd(), train, download=True)
        self._data_len = len(self._data)

    def __getitem__(self, index):
        img, label = self._data[index]  # img is a PIL image
        return jnp.asarray(np.asarray(img)), label

    def __len__(self):
        return self._data_len


def collate_fn(batch):
    # Recursively stack samples (and tuples/lists of fields) into JAX arrays.
    if isinstance(batch[0], jnp.ndarray):
        return jnp.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return type(batch[0])(collate_fn(samples) for samples in zip(*batch))
    else:
        return jnp.asarray(batch)
```
The following code works:
```python
dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset,
                                         collate_fn=collate_fn,
                                         num_workers=0)
next(iter(dataloader))
```
But the following code just hangs:
```python
dataset = SpecialMNIST(train=False)
dataloader = torch.utils.data.DataLoader(dataset,
                                         collate_fn=collate_fn,
                                         num_workers=1)
next(iter(dataloader))
```
The problem goes away when I change the `jnp` arrays to `np` arrays in the dataset and the collate function. However, in my actual use case I have some complex data augmentation that I would like to JIT-compile and run on the CPU. Is there any way to do this in JAX, or do I have to stick to plain NumPy functions for data loading? Thanks!
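One way to keep a JIT-compiled augmentation while keeping JAX out of the loader workers, sketched under assumptions: have `__getitem__` return `np.asarray(img)`, collate with NumPy only, and apply the jitted augmentation to each batch in the main process. `numpy_collate` mirrors the `collate_fn` above with `np` in place of `jnp`; `augment` is a hypothetical stand-in (rescale to [0, 1] plus a random horizontal flip) for the real augmentation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_collate(batch):
    # Same recursion as collate_fn, but NumPy-only, so no JAX runs in the workers.
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    else:
        return np.asarray(batch)


@jax.jit
def augment(images, key):
    # Hypothetical augmentation: rescale uint8 pixels to [0, 1] and flip
    # each image horizontally with probability 0.5.
    images = images.astype(jnp.float32) / 255.0
    flip = jax.random.bernoulli(key, 0.5, (images.shape[0],))
    return jnp.where(flip[:, None, None], images[:, :, ::-1], images)


if __name__ == "__main__":
    # Fake MNIST-like samples: (28x28 uint8 image, int label) pairs.
    batch = [(np.full((28, 28), i, dtype=np.uint8), i) for i in range(4)]
    images, labels = numpy_collate(batch)  # what a DataLoader worker would do
    out = augment(images, jax.random.PRNGKey(0))  # done in the main process
    print(out.shape, labels)
```

With a real loader, pass `collate_fn=numpy_collate` to the `DataLoader` and call `augment(images, key)` on each batch it yields, splitting `key` with `jax.random.split` per step. The trade-off is that the augmentation then runs in the main process rather than in the loader workers.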
Issue created 3 years ago; 3 reactions; 6 comments (4 by maintainers).
@n2cholas would adding … at the top of your script help? What you're seeing might be due to the XLA state being inconsistent after forking off the data loading processes. It will make their start-up quite a bit more expensive, but it might at least be more correct. Also, please remember to wrap your code in `if __name__ == '__main__':` if you do that (see the `multiprocessing` docs).

I think you can use `np.asarray()` to convert to a regular NumPy array after the call to JAX; if I am not mistaken, this avoids creating a copy of the data.
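Assuming the suggested snippet switches the `multiprocessing` start method from fork to spawn (consistent with the remarks about forking, start-up cost, and the `__main__` guard), a minimal sketch of that setup using only the standard library and JAX follows. The names `augment`, `load_sample`, and `load_in_workers` are hypothetical. Each spawned worker is a fresh interpreter that initialises its own XLA state instead of inheriting a forked copy, and the worker hands its result back with `np.asarray()` as in the second suggestion.

```python
import multiprocessing as mp

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def augment(x):
    # Hypothetical stand-in for a JIT-compiled CPU augmentation.
    return jnp.flip(x, axis=-1) * 2.0


def load_sample(index):
    # Runs inside a worker process. Under "spawn" the worker is a fresh
    # interpreter, so JAX/XLA is initialised there rather than inherited via fork.
    x = jnp.arange(4.0) + index
    # Return a plain NumPy array across the process boundary.
    return np.asarray(augment(x))


def load_in_workers(indices, num_workers=2):
    # Per-pool alternative to calling mp.set_start_method("spawn") globally.
    ctx = mp.get_context("spawn")
    with ctx.Pool(num_workers) as pool:
        return pool.map(load_sample, indices)


if __name__ == "__main__":
    # The guard is required with "spawn": each worker re-imports this module
    # and must not start a pool of its own.
    print(load_in_workers(range(3)))
```

With PyTorch, the corresponding setting is the `multiprocessing_context="spawn"` argument of `torch.utils.data.DataLoader` (or a global `multiprocessing.set_start_method("spawn")` inside the `__main__` guard); expect slower worker start-up, as noted above.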