play_char training is broken; CharDataset is not multiprocessing compatible

I discovered that the `CharDataset` implementation is broken and returns the same batch of data multiple times in a row. This causes massive overfitting and wasted cycles. The root issue is that `CharDataset` is not multiprocessing compatible, but `num_workers` is greater than 1, so it gets used in multiprocessing mode.
Details
The source of the problem is this line in `__getitem__`:

```python
i = np.random.randint(0, len(self.data) - (self.block_size + 1))
```

`CharDataset` is fed into a `DataLoader` during training, with `num_workers` set greater than 1. This puts the `DataLoader` into multiprocessing mode, where it distributes the `Dataset` to multiple worker processes. The crux of the issue is that in doing so it copies over local state, including, for example, the state of random number generators. So the line above returns the exact same sequence of "random" indexes in every worker process. This results in the same batch of data being repeated four times in a row, before the next batch is repeated four times, and so on.
Here is a notebook that simplifies play_char to demonstrate the issue: https://gist.github.com/fpgaminer/7737a9377e3379fe17dc5bb83d4db69c

In the simplified notebook, `__getitem__` returns `i` directly. The last cell iterates the loader and prints out the batches. As can be seen, each batch is repeated four times.
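For reference, here is a minimal standalone sketch (an approximation of the gist, not its exact code) that shows the same duplication on platforms where `DataLoader` workers are started by forking; with the spawn start method the workers may be seeded differently:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class RandomIndexDataset(Dataset):
    """Toy dataset whose __getitem__ draws a 'random' index with NumPy,
    mirroring the problematic line in CharDataset."""

    def __len__(self):
        return 32

    def __getitem__(self, idx):
        # Each forked worker inherits the same NumPy RNG state, so this
        # sequence of "random" values is identical across workers.
        return int(np.random.randint(0, 1000))

if __name__ == "__main__":
    loader = DataLoader(RandomIndexDataset(), batch_size=4, num_workers=4)
    for batch in loader:
        # With 4 workers, the same "random" batch shows up 4 times in a row
        # before the next (also repeated) batch appears.
        print(batch)
```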
Workaround
The workaround for me was to set `num_workers` to 1. Before the workaround, the model showed signs of overfitting on WebText2, which shouldn't be possible. After the workaround, the model started to train correctly and test loss began dropping as expected.
Fix
I haven't worked with raw PyTorch much, so I don't know the idiomatic fix. I'm happy to research and propose a pull request if you would like. Perhaps the easiest fix is to use the workaround and drop an assert into `CharDataset` that throws if multiprocessing gets used; a sketch of such a guard follows. Since the dataset is in-memory there's little reason to use multiple workers, and larger datasets would really need a different Dataset implementation anyway.
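For illustration only (not necessarily the idiomatic fix, and with hypothetical names), such a guard could use `torch.utils.data.get_worker_info()`, which returns `None` in the main process and worker metadata inside a `DataLoader` worker process:

```python
import numpy as np
import torch
import torch.utils.data
from torch.utils.data import Dataset

class GuardedCharDataset(Dataset):
    """Hypothetical sketch: refuse to run with more than one DataLoader
    worker, since each worker would inherit the same NumPy RNG state."""

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        return len(self.data) // self.block_size

    def __getitem__(self, idx):
        worker_info = torch.utils.data.get_worker_info()
        assert worker_info is None or worker_info.num_workers <= 1, \
            "CharDataset is not multiprocessing compatible; use num_workers <= 1"
        # Simplified version of the original random-offset sampling.
        i = np.random.randint(0, len(self.data) - (self.block_size + 1))
        chunk = self.data[i:i + self.block_size + 1]
        return (torch.tensor(chunk[:-1], dtype=torch.long),
                torch.tensor(chunk[1:], dtype=torch.long))
```

An alternative that keeps multiple workers usable would be a `worker_init_fn` that reseeds NumPy per worker, but the assert is the smallest change.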
Top GitHub Comments
Addressed in https://github.com/karpathy/minGPT/commit/339f4e7ad39558bfd7e99d916b9fdd6c6827f807
As mentioned in the commit, I'm not happy with this demo still, and I'm not happy that epochs will now take a super long time. Closing the issue for now.
Depends. The old intended behavior is to sample the dataset at random indexes, for `len(data) / block_size` samples. So we'd need the DataLoader and co. to feed `__getitem__` with indexes from 0 to `len(data) - block_size - 1`. That means `__len__` needs to return `len(data)` so the DataLoader can do that. But `DataLoader` doesn't have a `num_samples` option. So it's going to run the epoch for `len(data)` samples, or basically 128x longer than currently (or whatever you set `block_size` to). At least, as far as I can tell from PyTorch's docs.
To recreate the original intended behavior we need to manually feed a `RandomSampler`, with `num_samples` set to `len(data) / block_size`, to the `DataLoader`, roughly as sketched below.
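A rough sketch of that approach, under the assumption that `__getitem__` is changed to treat the incoming index as the start offset; class and variable names here are illustrative, not the actual minGPT code:

```python
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler

class StartIndexDataset(Dataset):
    """Illustrative variant: the DataLoader supplies the start index,
    so no NumPy RNG is needed inside __getitem__."""

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # One entry per valid start position.
        return len(self.data) - self.block_size - 1

    def __getitem__(self, idx):
        chunk = self.data[idx:idx + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

data = list(range(100_000))          # stand-in for an encoded text corpus
block_size = 128
dataset = StartIndexDataset(data, block_size)

# Restore the old epoch length: len(data) / block_size samples per epoch,
# drawn at random with replacement, much as np.random.randint did.
sampler = RandomSampler(dataset, replacement=True,
                        num_samples=len(data) // block_size)
loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=4)
```

Because the sampler draws the random indices in the main process, the worker processes no longer rely on NumPy RNG state at all, so `num_workers > 1` becomes safe.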
The alternative is to have `CharDataset` pre-chunk the data. Then `__len__` can stay the same, and `__getitem__` just uses `self.chunked_data[idx]` to grab a sample. That's different behavior though: every epoch you'll feed the same "chunks", albeit in a different order, whereas the current behavior chops the chunks out of the data at random positions. Maybe it doesn't matter, I'm not sure; I figure it could cause issues on smaller datasets. A sketch of the pre-chunking variant follows.
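And a minimal sketch of the pre-chunking alternative, again with illustrative names and assuming non-overlapping fixed chunks:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PreChunkedCharDataset(Dataset):
    """Illustrative alternative: split the data into fixed, non-overlapping
    chunks up front so __getitem__ is a plain deterministic lookup."""

    def __init__(self, data, block_size):
        self.block_size = block_size
        # Fixed chunks of block_size + 1 tokens (input plus shifted target).
        self.chunked_data = [
            data[i:i + block_size + 1]
            for i in range(0, len(data) - block_size - 1, block_size)
        ]

    def __len__(self):
        return len(self.chunked_data)

    def __getitem__(self, idx):
        chunk = self.chunked_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

data = list(range(100_000))          # stand-in for an encoded text corpus
dataset = PreChunkedCharDataset(data, block_size=128)
# shuffle=True reorders the fixed chunks each epoch; the chunk boundaries
# themselves never change, which is the behavioral difference noted above.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
```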