
Petastorm sharding + Distributed PyTorch

See original GitHub issue

Problem: I would like to train a PyTorch model on a Parquet dataset in a distributed (multi-GPU, multi-machine) setup, for a fixed number of epochs. For this I need to shard the dataset, and I hoped that providing Petastorm’s cur_shard and shard_count would be sufficient. I create a Petastorm reader with num_epochs=1 at the start of each epoch (or could create it once and reset()):

    import petastorm
    import petastorm.pytorch

    for epoch in range(epochs):
        # New reader each epoch: a single pass (num_epochs=1) over this rank's shard.
        train_loader = petastorm.pytorch.DataLoader(petastorm.make_reader(
            ...,  # dataset URL and other reader arguments elided
            num_epochs=1,
            cur_shard=ctx.global_rank,
            shard_count=ctx.world_size,
        ))

        for i, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(ctx.device)
            targets = targets.to(ctx.device)

            predictions = model(inputs)
            loss = loss_fn(predictions, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

But when training a DistributedDataParallel model, PyTorch expects shards to have the same number of examples, so that every rank sees the same number of batches, makes the same number of training steps, and participates in the same number of allreduces. For example, torch.utils.data.distributed.DistributedSampler (used to implement sharding with stock PyTorch’s DataLoader) wraps the dataset around to make it evenly divisible by the number of shards.
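
For reference, here is a minimal standalone sketch (toy dataset of 10 examples, 3 ranks; not from the issue) showing that wraparound: DistributedSampler pads the index list so every rank gets the same number of indices. Passing num_replicas and rank explicitly means no process group is needed to run it.

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # 10 examples split across 3 ranks: the sampler pads the index list to 12
    # by wrapping around, so every rank gets exactly 4 indices and therefore
    # the same number of batches and allreduces.
    dataset = TensorDataset(torch.arange(10))
    for rank in range(3):
        sampler = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
        print(rank, list(sampler))
    # 0 [0, 3, 6, 9]
    # 1 [1, 4, 7, 0]   <- index 0 wrapped around
    # 2 [2, 5, 8, 1]   <- index 1 wrapped around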

If shards are not even-sized, some ranks have less work to do; they finish their shard early and start the next epoch while the rest of the ranks are still processing the previous one. If we’re training for a certain number of epochs, the “fast ranks” eventually finish first and terminate, leaving the rest of the ranks hanging because allreduce is now impossible.

In Petastorm, even-sized shards are not guaranteed. The length of the dataset is unknown and we don’t have random access to rows. Rowgroups are assigned to shards in a round-robin fashion without wraparound, so one rank can get more rowgroups than another. Moreover, rowgroups might not have the same number of rows, and applying row predicates can change the balance further.
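
A toy illustration of the rowgroup imbalance (plain Python, not Petastorm internals; the numbers are made up): 10 rowgroups assigned round-robin to 4 shards without wraparound leave two shards with an extra rowgroup, before per-rowgroup row counts and predicates skew things further.

    rowgroups = list(range(10))                      # 10 rowgroups, 4 shards
    shards = {s: rowgroups[s::4] for s in range(4)}  # round-robin, no wraparound
    print(shards)
    # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}
    # Shards 0 and 1 get 3 rowgroups each, shards 2 and 3 only 2.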

Possible solutions: here is what I have considered so far.

  • Option 1 is to scan the dataset once at runtime before training, figure out which rows match the predicate, wrap around if needed to make the total size divisible, and then assign row indexes to shards so that each shard has the same number of rows. The unit of work for Petastorm then becomes a rowgroup index plus the indexes of rows within that rowgroup, effectively a poor man’s random-access-to-row implementation, which would be pretty slow. Alternatively, whole rowgroups could be assigned to shards while keeping the shards even-sized, which is a sort of bin-packing problem.

  • Option 2 is to use Petastorm sharding but set infinite epochs, so that the “fast ranks” wrap around, and cut the epoch off in the train loop outside of Petastorm (see the sketch after this list). But then the smaller shards effectively get oversampled, and it wouldn’t help with empty shards.

  • Option 3 is to not shard at all: make each rank read the whole dataset, but shuffle it in a unique way, independent of the other ranks. The resulting mega-epoch will be as long as num_shards normal epochs. If the shuffling is good, each global batch is effectively drawing a batch-sized random sample (without replacement) num_shards times (with replacement).
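
Here is a minimal sketch of Option 2, assuming dataset_url, batch_size, steps_per_epoch and epochs are defined elsewhere (they are not part of the snippet above): the reader loops over this rank’s shard indefinitely (num_epochs=None), and the train loop cuts each “epoch” off after a fixed number of steps, so every rank performs the same number of optimizer steps and allreduces.

    import petastorm
    import petastorm.pytorch

    reader = petastorm.make_reader(
        dataset_url,                  # hypothetical name for the Parquet dataset URL
        num_epochs=None,              # None = iterate over this rank's shard forever
        cur_shard=ctx.global_rank,
        shard_count=ctx.world_size,
    )
    train_loader = petastorm.pytorch.DataLoader(reader, batch_size=batch_size)
    batches = iter(train_loader)

    for epoch in range(epochs):
        for step in range(steps_per_epoch):   # same fixed count on every rank
            batch = next(batches)             # never exhausted: epochs are infinite
            # unpack the batch and run the forward/backward pass as in the loop above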

@selitvin is my assessment of the problem valid? Do you have some solutions in mind? Thank you!

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Reactions: 4
  • Comments: 14 (2 by maintainers)

Top GitHub Comments

1 reaction
selitvin commented, Mar 20, 2020

IterableDataset is a very interesting direction. Having a more PyTorch-native way of parallel loading/processing could probably substitute for Petastorm’s custom worker pools, give PyTorch users a more familiar look & feel, and improve performance thanks to the better, shared-memory-based IPC mechanism already in place in PyTorch.
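
For illustration, a minimal sketch of that direction (not from this thread), assuming a make_reader_fn factory that builds a sharded Petastorm reader, e.g. a functools.partial over petastorm.make_reader:

    import torch.utils.data

    class PetastormIterable(torch.utils.data.IterableDataset):
        """Exposes a Petastorm reader as a torch IterableDataset."""

        def __init__(self, make_reader_fn):
            super().__init__()
            self.make_reader_fn = make_reader_fn

        def __iter__(self):
            # Petastorm readers are iterable and yield namedtuple rows.
            with self.make_reader_fn() as reader:
                for row in reader:
                    yield row._asdict()

A stock torch.utils.data.DataLoader could then handle batching and collation on the PyTorch side; splitting work across DataLoader worker processes would still need per-worker sharding logic on top of this sketch.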

0 reactions
zxgx commented, Aug 25, 2022

I’ve been too busy to handle the dataloader recently 😵, which is why I resorted to Petastorm. I may look into the source code in the future, but I would recommend fixing this load-balancing issue for DDP first, as this is a very common need for large-scale datasets and distributed training.

Read more comments on GitHub.

