It looks like random_split cannot generate label-balanced sub-datasets
🐛 Bug
The PyTorch built-in function torch.utils.data.random_split is used in multiple DataModules. However, this function splits the dataset uniformly at random with no stratification, so it cannot generate label-balanced sub-datasets.
To Reproduce
Steps to reproduce the behavior:
Run the following code:
from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule('/localhome/fair/Dataset/cifar10')

# Count how many samples of each of the 10 CIFAR-10 classes
# end up in the training split produced by random_split
stat = [0 for i in range(10)]
for batch in dm.train_dataloader():
    inputs, targets = batch
    for b in range(targets.size()[0]):
        stat[targets[b].item()] += 1
print(stat)
and it will output:
[4512, 4486, 4466, 4529, 4528, 4485, 4493, 4499, 4495, 4499]
Expected behavior
We want a label-balanced output; that is, the split sub-datasets should preserve the label proportions of the original dataset. Since the CIFAR-10 training set is perfectly balanced (5,000 images per class), each of the 10 classes should contribute the same number of samples to the training split:
[4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500, 4500]
Environment
- PyTorch Version: 1.6
- OS: Ubuntu 18.04 LTS
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): N/A
- Python version: 3.8
- CUDA/cuDNN version: 10.2
- GPU models and configuration: 2080Ti
- Any other relevant information:
Additional context

This bug of random_split has been discussed many times in the community, such as Link. It may produce an inconspicuously skewed val_dataset, especially when the dataset is not large enough.
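To make the skew concrete, here is a minimal sketch (not from the original issue; the toy dataset, split sizes, and seed are illustrative assumptions) that applies torch.utils.data.random_split to a small, perfectly balanced dataset and counts the labels that land in the validation subset:

# Toy demonstration (illustrative, not from the original issue):
# a 100-sample dataset with exactly 10 samples per class, split 80/20.
import torch
from collections import Counter
from torch.utils.data import TensorDataset, random_split

labels = torch.arange(10).repeat(10)   # 100 labels, exactly 10 per class
data = torch.randn(100, 3)
dataset = TensorDataset(data, labels)

train_set, val_set = random_split(
    dataset, [80, 20], generator=torch.Generator().manual_seed(42)
)
val_counts = Counter(int(labels[i]) for i in val_set.indices)
print(sorted(val_counts.items()))
# The per-class counts in the 20-sample validation set are typically
# uneven (some classes appear 0 or 4+ times) instead of 2 per class.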
A better implementation is sklearn.model_selection.train_test_split; an example is sketched below.
Thanks a lot for your great pl_bolts. I also find that some DataModules, such as Imagenet_dataset, already depend on sklearn, so it looks like this is not a serious problem… maybe?
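As a sketch of the idea (the dataset path, split size, and seed below are illustrative assumptions, not from the issue), a label-stratified split can be built with train_test_split's stratify argument and wrapped in torch.utils.data.Subset:

# Minimal sketch (illustrative assumptions) of a label-stratified split
# using sklearn.model_selection.train_test_split.
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torchvision.datasets import CIFAR10

dataset = CIFAR10('/path/to/cifar10', train=True, download=True)  # placeholder path
targets = np.array(dataset.targets)

train_idx, val_idx = train_test_split(
    np.arange(len(targets)),
    test_size=5000,      # a 5,000-sample validation split, as in the report above
    stratify=targets,    # preserve the per-class label proportions
    random_state=42,
)
train_set = Subset(dataset, train_idx)
val_set = Subset(dataset, val_idx)
# With the balanced CIFAR-10 training set, every class now contributes
# exactly 4,500 samples to train_set and 500 to val_set.

The stratify argument makes sklearn perform a stratified shuffle split, so the class proportions of targets are reproduced in both index sets rather than left to chance.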