Error in "on_advance_start" when data-loader's sampler is a NumPy array
🐛 Bug
When using a NumPy array as the sampler of a PyTorch DataLoader, the check

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None)
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):

in "on_advance_start" raises the following exception:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
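The failure comes from evaluating the array in a boolean context: the middle condition of the `and` chain asks for the truth value of the NumPy array returned by getattr, which NumPy refuses for arrays with more than one element. A minimal illustration of the underlying behaviour, independent of Lightning:

import numpy as np

sampler = np.array([1, 2, 3, 4])

try:
    # The `and` chain above implicitly calls bool(sampler), which NumPy rejects
    # for arrays with more than one element.
    if sampler:
        pass
except ValueError as err:
    print(err)  # The truth value of an array with more than one element is ambiguous...

# An explicit None comparison avoids the ambiguity entirely.
print(sampler is not None)  # True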
To Reproduce
import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
import numpy as np


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    # Passing a NumPy array as the sampler reproduces the error in "on_advance_start".
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2, sampler=np.array([1, 2, 3, 4]))
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


run()
Expected behavior
The error is not raised.
Environment
- CUDA:
  - GPU:
    - Tesla T4
  - available: True
  - version: 11.3
- Packages:
  - numpy: 1.21.6
  - pyTorch_debug: False
  - pyTorch_version: 1.11.0+cu113
  - pytorch-lightning: 1.6.4
  - tqdm: 4.64.0
- System:
  - OS: Linux
  - architecture:
    - 64bit
  - processor: x86_64
  - python: 3.7.13
  - version: #1 SMP Sun Apr 24 10:03:06 PDT 2022
Additional context
An easy solution is to change the check that raises the error to

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None) is not None
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):

assuming the only intent is to verify that the sampler exists and is not None.
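As a quick sanity check (a standalone sketch, not the actual Lightning loop code), the proposed condition can be exercised against a DataLoader whose sampler is a NumPy array:

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

val_data = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=2, sampler=np.array([1, 2, 3, 4]))

sampler = getattr(val_data, "sampler", None)
# `sampler is not None` never asks for the array's truth value, so no ValueError is raised;
# since a NumPy array has no set_epoch method, the branch is simply skipped.
if val_data is not None and sampler is not None and callable(getattr(sampler, "set_epoch", None)):
    sampler.set_epoch(0)
else:
    print("sampler exists but has no set_epoch; nothing to do")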
cc @borda @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj
@BaruchG Thanks for asking; however, I think this is already fixed in 1.6.5, as I am not able to reproduce the bug anymore. @carmocca can you confirm this? If that is the case, I think the issue can be marked as solved, maybe linking the PR that solved it.
Thanks for the heads up! This was fixed by https://github.com/Lightning-AI/lightning/pull/13396