
random_split appears unable to generate label-balanced sub-datasets

See original GitHub issue

🐛 Bug

The PyTorch built-in function torch.utils.data.random_split is used in multiple DataModules. However, this function splits purely at random, so it cannot generate label-balanced sub-datasets.
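
For reference, random_split draws a random permutation of indices and slices it by the requested lengths; it never consults the labels, so exact class balance is not guaranteed. A minimal sketch on toy data (the shapes and 45,000/5,000 lengths are illustrative, not from the issue):

import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for CIFAR-10: 50,000 samples across 10 classes.
data = torch.randn(50_000, 4)
labels = torch.randint(0, 10, (50_000,))
dataset = TensorDataset(data, labels)

# random_split shuffles indices uniformly at random; the labels are
# never consulted, so per-class counts in each split fluctuate.
train_set, val_set = random_split(dataset, [45_000, 5_000])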

To Reproduce

Steps to reproduce the behavior:

Run the following code:

from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule('/localhome/fair/Dataset/cifar10')

# Count how many samples of each of the 10 classes land in the train split.
stat = [0 for i in range(10)]
for batch in dm.train_dataloader():
    inputs, targets = batch
    for b in range(targets.size(0)):
        stat[targets[b].item()] += 1
stat

and it will output:

[4512, 4486, 4466, 4529, 4528, 4485, 4493, 4499, 4495, 4499]

Expected behavior

We want a label-balanced output; that is, the label distribution of each split sub-dataset should have the same proportions as the original dataset. A 90/10 split of CIFAR-10's 50,000 training samples (5,000 per class) should leave exactly 4,500 samples of each class in the train split:

[4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500]

Environment

  • PyTorch Version: 1.6
  • OS: Ubuntu 18.04 LTS
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: 2080Ti


Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

2 reactions
zhutmost commented, Oct 27, 2020

The random_split bug has been discussed many times in the community (see the linked discussion). It may produce an inconspicuously skewed val_dataset, especially when the dataset is not large enough.
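
To make the skew concrete, here is a small hedged experiment (toy data, not from the issue): splitting 1,000 perfectly balanced samples and counting labels in the 100-sample validation part typically yields visibly uneven per-class counts:

import torch
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(0)

# Toy dataset: 1,000 samples, 10 perfectly balanced classes (100 each).
labels = torch.arange(10).repeat_interleave(100)
dataset = TensorDataset(torch.randn(1_000, 4), labels)

train_set, val_set = random_split(dataset, [900, 100])

# Ideally each class would appear exactly 10 times in the val split.
counts = torch.zeros(10, dtype=torch.long)
for _, y in val_set:
    counts[y] += 1
print(counts)  # per-class counts deviate from 10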

A better implementation uses sklearn.model_selection.train_test_split; here is an example:

import numpy as np
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

def __balance_val_split(dataset, val_split=0.1):
    # Stratify on the dataset's targets so both splits keep the same
    # class proportions as the original dataset.
    targets = np.array(dataset.targets)
    train_indices, val_indices = train_test_split(
        np.arange(targets.shape[0]),
        test_size=val_split,
        stratify=targets
    )
    train_dataset = Subset(dataset, indices=train_indices)
    val_dataset = Subset(dataset, indices=val_indices)
    return train_dataset, val_dataset
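
A hypothetical usage of the helper above, assuming a plain torchvision CIFAR10 dataset (the path, transform, and 10% split are illustrative):

import torchvision as tv

# Placeholders: adjust root/transform to your setup.
dataset = tv.datasets.CIFAR10(root='./data', train=True, download=True,
                              transform=tv.transforms.ToTensor())
train_dataset, val_dataset = __balance_val_split(dataset, val_split=0.1)
# len(val_dataset) == 5000, with exactly 500 samples of each class.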

0 reactions
zhutmost commented, Nov 26, 2020

> The reason why I moved to the PT split was to drop the dependency on sklearn, as this seems to be its only usage… :] So shall we use it again, or just implement the split ourselves?

Thanks a lot for your great pl_bolts. I find that some DataModules, such as Imagenet_dataset, also depend on sklearn, so it looks like this is not a serious problem… maybe?
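
If dropping the sklearn dependency is the concern, a stratified split is also easy to hand-roll. A minimal sketch, assuming the dataset exposes a targets attribute as torchvision datasets do (the function name is illustrative, not pl_bolts API):

import numpy as np
from torch.utils.data import Subset

def stratified_split(dataset, val_split=0.1, seed=42):
    targets = np.array(dataset.targets)
    rng = np.random.default_rng(seed)
    train_indices, val_indices = [], []
    # Shuffle indices within each class, then carve off the same
    # fraction of every class for validation.
    for cls in np.unique(targets):
        cls_indices = np.where(targets == cls)[0]
        rng.shuffle(cls_indices)
        n_val = int(round(len(cls_indices) * val_split))
        val_indices.extend(cls_indices[:n_val].tolist())
        train_indices.extend(cls_indices[n_val:].tolist())
    return Subset(dataset, train_indices), Subset(dataset, val_indices)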
