
play_char training is broken; CharDataset is not multiprocessing compatible

See original GitHub issue

I discovered that the CharDataset implementation is broken and returns the same batch of data multiple times in a row. This causes massive overfitting and wasted cycles.

The root issue is that CharDataset is not multiprocessing compatible, but num_workers is set greater than 1, so the DataLoader uses it in multiprocessing mode.

Details

The source of the problem is this line in __getitem__:

        i = np.random.randint(0, len(self.data) - (self.block_size + 1))

CharDataset is fed into a DataLoader during training with num_workers set greater than 1. This puts the DataLoader into multiprocessing mode, where it distributes the Dataset to multiple worker processes. The crux of the issue is that in doing so it copies over local state, including, for example, the state of random number generators. So the line above returns the exact same sequence of “random” indexes in every worker process. The result is the same batch of data repeated once per worker (four times in a row with four workers), then the next batch repeated four times, and so on.

Here is a notebook that simplifies play_char to demonstrate the issue: https://gist.github.com/fpgaminer/7737a9377e3379fe17dc5bb83d4db69c

In the simplified notebook, __getitem__ returns i directly. The last cell iterates the loader and prints the batches; each batch appears four times in a row.
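For anyone who doesn’t want to open the gist, here is a standalone sketch along the same lines (not the gist itself; the dataset is a toy stand-in for CharDataset). On platforms where DataLoader workers are forked, such as the Linux default, it reproduces the duplicated batches:

    import numpy as np
    from torch.utils.data import Dataset, DataLoader

    class RandomIndexDataset(Dataset):
        """Toy stand-in for CharDataset: __getitem__ ignores idx and draws a
        'random' index, just like the original implementation."""
        def __len__(self):
            return 32

        def __getitem__(self, idx):
            # numpy's global RNG state is copied into each forked worker, so
            # every worker produces the same sequence of "random" numbers
            return np.random.randint(0, 1000)

    if __name__ == "__main__":
        loader = DataLoader(RandomIndexDataset(), batch_size=4, num_workers=4)
        for batch in loader:
            print(batch)  # each batch shows up four times in a row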

Workaround

The workaround for me was to set num_workers to 1. Before the workaround, the model showed signs of overfitting on WebText2, which shouldn’t be possible. After the workaround, the model started to train correctly and test loss began dropping as expected.

Fix

I haven’t worked with raw PyTorch much, so I don’t know the idiomatic fix; I’m happy to research and propose a pull request if you’d like. Perhaps the easiest fix is to keep the workaround and drop an assert into CharDataset that throws if it is used with multiprocessing. Since the dataset is in-memory, there’s little reason to use multiple workers, and larger datasets would need a different Dataset implementation anyway.
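A rough sketch of what such a guard could look like — torch.utils.data.get_worker_info() is the real PyTorch helper, but the class body here is abbreviated and is not the repo’s exact code:

    import numpy as np
    import torch
    from torch.utils.data import Dataset, get_worker_info

    class CharDataset(Dataset):
        """Abbreviated sketch of CharDataset with a guard against workers."""
        def __init__(self, data, block_size):
            self.stoi = {ch: i for i, ch in enumerate(sorted(set(data)))}
            self.data = data
            self.block_size = block_size

        def __len__(self):
            return len(self.data) // self.block_size

        def __getitem__(self, idx):
            # get_worker_info() returns None in the main process; in a worker it
            # reports how many workers the DataLoader spawned. More than one
            # worker means the copied numpy RNG state hands every worker the
            # same "random" indexes.
            info = get_worker_info()
            assert info is None or info.num_workers <= 1, \
                "CharDataset is not multiprocessing compatible; use num_workers <= 1"
            i = np.random.randint(0, len(self.data) - (self.block_size + 1))
            dix = [self.stoi[s] for s in self.data[i:i + self.block_size + 1]]
            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)
            return x, y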

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 11 (11 by maintainers)

Top GitHub Comments

2 reactions
karpathy commented, Aug 25, 2020

addressed in https://github.com/karpathy/minGPT/commit/339f4e7ad39558bfd7e99d916b9fdd6c6827f807

as mentioned in commit i’m not happy with this demo still, and i’m not happy that epochs will now take a super long time. Closing the issue for now.
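For readers who don’t want to open the commit, the general shape of this kind of fix (a sketch of the approach, not necessarily the exact code in that commit) is to make __getitem__ deterministic in idx; it also shows why epochs get so long, since __len__ now covers every possible starting position:

    import torch
    from torch.utils.data import Dataset

    class CharDataset(Dataset):
        def __init__(self, data, block_size):
            self.stoi = {ch: i for i, ch in enumerate(sorted(set(data)))}
            self.data = data
            self.block_size = block_size

        def __len__(self):
            # one example per possible starting position -> much longer epochs
            return len(self.data) - self.block_size

        def __getitem__(self, idx):
            # deterministic: the chunk depends only on idx, so randomness comes
            # from the DataLoader's sampler, not from each worker's RNG copy
            dix = [self.stoi[s] for s in self.data[idx:idx + self.block_size + 1]]
            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)
            return x, y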

1 reaction
fpgaminer commented, Aug 24, 2020

Depends. The old intended behavior is to sample the dataset at random indexes, for len(data) / block_size samples. So we’d need the DataLoader and co. to feed __getitem__ indexes from 0 to (len(data) - block_size - 1). That means __len__ needs to return len(data) so the DataLoader can do that. But DataLoader doesn’t have a num_samples option, so it’s going to run the epoch for len(data) samples, or basically 128x longer than currently (or whatever you set block_size to).

At least, as far as I can tell from PyTorch’s docs.

To recreate the original intended behavior, we need to manually pass the DataLoader a RandomSampler with num_samples set to len(data) / block_size.
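A rough sketch of that sampler-based approach (train_dataset here is assumed to be a deterministic CharDataset like the one sketched above, whose __len__ exposes every starting position; note that older PyTorch versions require replacement=True whenever num_samples is given):

    from torch.utils.data import DataLoader, RandomSampler

    # cap the epoch at ~len(data) / block_size random draws instead of
    # iterating over every possible starting position
    num_samples = len(train_dataset.data) // train_dataset.block_size
    sampler = RandomSampler(train_dataset, replacement=True,
                            num_samples=num_samples)
    loader = DataLoader(train_dataset, sampler=sampler, batch_size=128,
                        pin_memory=True, num_workers=4)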

The alternative is to have CharDataset pre-chunk the data. Then __len__ can stay the same, and __getitem__ just uses self.chunked_data[idx] to grab a sample. That’s different behavior, though: every epoch you feed the same “chunks”, albeit in a different order, whereas the current behavior chops the chunks out of the data at random offsets. Maybe it doesn’t matter, I’m not sure; I figure it could cause issues on smaller datasets.
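A rough sketch of that pre-chunking alternative (the class name is made up for illustration):

    import torch
    from torch.utils.data import Dataset

    class PreChunkedCharDataset(Dataset):
        """Chops the text into fixed chunks up front, so __getitem__ is a plain
        lookup and __len__ keeps roughly its old meaning."""
        def __init__(self, data, block_size):
            self.stoi = {ch: i for i, ch in enumerate(sorted(set(data)))}
            self.block_size = block_size
            # the same chunks every epoch, just visited in a sampler-chosen order
            self.chunked_data = [data[i:i + block_size + 1]
                                 for i in range(0, len(data) - block_size, block_size)]

        def __len__(self):
            return len(self.chunked_data)

        def __getitem__(self, idx):
            dix = [self.stoi[s] for s in self.chunked_data[idx]]
            x = torch.tensor(dix[:-1], dtype=torch.long)
            y = torch.tensor(dix[1:], dtype=torch.long)
            return x, y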

