Petastorm sharding + Distributed PyTorch
Problem:
I would like to train a PyTorch model on a Parquet dataset in a distributed (multi-GPU, multi-machine) setup, for a fixed number of epochs. For this, I need to shard the dataset, and I hoped that providing Petastorm's `cur_shard` and `shard_count` would be sufficient. I create the Petastorm reader with `num_epochs=1` each epoch (or could create it once and call `reset()`):
```python
import petastorm
import petastorm.pytorch

for epoch in range(epochs):
    # A fresh reader per epoch; cur_shard/shard_count select this rank's shard.
    train_loader = petastorm.pytorch.DataLoader(petastorm.make_reader(
        ...,
        num_epochs=1,
        cur_shard=ctx.global_rank,
        shard_count=ctx.world_size,
    ))
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(ctx.device)
        targets = targets.to(ctx.device)
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
But when training a `DistributedDataParallel` model, PyTorch expects shards to have the same number of examples, so that they have the same number of batches, so that all ranks make the same number of training steps and participate in the same number of allreduces. E.g. `torch.utils.data.distributed.DistributedSampler` (used to implement sharding in stock PyTorch's `DataLoader`) wraps the dataset around to make it evenly divisible by the number of shards.
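For reference, a minimal sketch of that wrap-around idea (not the actual `DistributedSampler` source, just roughly the padding scheme it uses): indices are repeated from the beginning until the total is divisible by the world size, then each rank takes a strided slice, so every rank ends up with the same number of examples.

```python
import math

def shard_indices(num_examples, rank, world_size):
    """Pad-and-stride sharding, roughly what DistributedSampler does."""
    per_rank = math.ceil(num_examples / world_size)
    total = per_rank * world_size
    indices = list(range(num_examples))
    # Wrap around: repeat the first indices so every rank gets `per_rank` items.
    indices += indices[:total - num_examples]
    return indices[rank:total:world_size]

# Example: 10 examples, 4 ranks -> every rank gets 3 indices (2 are duplicated).
assert all(len(shard_indices(10, r, 4)) == 3 for r in range(4))
```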
If shards are not even-sized, some ranks have less work to do, finish their shard early, and start the next epoch while the rest of the ranks are still processing the previous one. If we're training for a certain number of epochs, the "fast ranks" eventually finish first and terminate, leaving the rest of the ranks hanging because allreduce is now impossible.
In Petastorm, even-sized shards are not guaranteed. The length of the dataset is unknown and we don't have random access to rows. Rowgroups are assigned to shards in a round-robin fashion without wraparound, so one rank can get more rowgroups than another. Moreover, rowgroups might not have the same number of rows, and applying row predicates can change the balance further.
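To illustrate the imbalance (a toy sketch, not Petastorm's actual code): with round-robin assignment and no wraparound, 10 rowgroups over 4 shards give two shards 3 rowgroups and two shards only 2, even before rowgroup sizes and predicates come into play.

```python
def round_robin_shards(num_rowgroups, shard_count):
    """Assign rowgroup indexes to shards round-robin, without wraparound."""
    return [list(range(shard, num_rowgroups, shard_count))
            for shard in range(shard_count)]

shards = round_robin_shards(10, 4)
print([len(s) for s in shards])  # [3, 3, 2, 2] -- uneven shard sizes
```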
Possible solutions: I thought about a few ways to make this work.

- Option 1 is to scan the dataset once at runtime before training, figure out which rows match the predicate, wrap around if needed to make the total size divisible, and then assign row indexes to shards so that each shard has the same number of rows. The unit of work for Petastorm then becomes rowgroup index + indexes of rows within the rowgroup, so effectively a poor man's random-access-to-row implementation, which would be pretty slow. Alternatively, whole rowgroups could be assigned to shards while keeping them even-sized, which is a sort of bin-packing problem.
- Option 2 is to use Petastorm sharding, but set infinite epochs so that the "fast ranks" wrap around, and cut off the epoch in the train loop outside of Petastorm (see the sketch after this list). But then the smaller shards effectively get oversampled, and it wouldn't help with empty shards.
- Option 3 is to not shard at all: make each rank read the whole dataset, but shuffle it in a unique way independent of the other ranks. The resulting mega-epoch will be as long as `num_shards` normal epochs. If shuffling is good, each global batch will effectively be drawing a batch-sized random sample (without replacement) `num_shards` times (with replacement).
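A minimal sketch of Option 2, assuming `ctx`, `model`, `loss_fn`, `optimizer`, and the elided reader options from the snippet above, plus a hypothetical `steps_per_epoch` that all ranks agree on in advance (e.g. derived from the smallest shard): the reader cycles forever and every rank cuts the epoch off after the same number of steps, so allreduces stay in lockstep.

```python
from itertools import islice

import petastorm
import petastorm.pytorch

# num_epochs=None makes the reader cycle over its shard indefinitely.
reader = petastorm.make_reader(
    ...,  # dataset URL and other options as in the snippet above
    num_epochs=None,
    cur_shard=ctx.global_rank,
    shard_count=ctx.world_size,
)
train_loader = petastorm.pytorch.DataLoader(reader)
batches = iter(train_loader)

for epoch in range(epochs):
    # Every rank takes exactly steps_per_epoch batches, regardless of shard size.
    for inputs, targets in islice(batches, steps_per_epoch):
        inputs = inputs.to(ctx.device)
        targets = targets.to(ctx.device)
        loss = loss_fn(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```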
@selitvin is my assessment of the problem valid? Do you have some solutions in mind? Thank you!
Top GitHub Comments
`IterableDataset` is a very interesting direction. Having a more PyTorch-native way of parallel loading/processing could probably be used to substitute Petastorm's custom worker pools, give PyTorch users a more PyTorch look & feel, and improve performance thanks to the better, shared-memory-based IPC mechanism already in place in PyTorch.

I'm too busy to work on the dataloader recently 😵, which is why I resorted to Petastorm. I may have a look into the source code in the future, but I would recommend fixing this load-balancing issue for DDP first, as this is a very common demand for large-scale datasets and distributed training.
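For illustration, a rough sketch of what wrapping a Petastorm reader in an `IterableDataset` could look like (my own sketch, not an API either project ships; the dataset URL and single-worker setup are placeholder assumptions):

```python
from torch.utils.data import DataLoader, IterableDataset
from petastorm import make_reader

class PetastormIterable(IterableDataset):
    """Streams rows from a Petastorm reader so torch's DataLoader can batch them."""

    def __init__(self, dataset_url, rank, world_size):
        self.dataset_url = dataset_url
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        # Each rank opens only its own shard of the Parquet dataset.
        with make_reader(self.dataset_url,
                         num_epochs=1,
                         cur_shard=self.rank,
                         shard_count=self.world_size) as reader:
            for row in reader:
                yield row._asdict()  # namedtuple row -> dict of numpy values

# torch's DataLoader handles batching/collation; num_workers=0 avoids every
# worker process re-reading the same shard.
loader = DataLoader(PetastormIterable("file:///tmp/dataset", rank=0, world_size=1),
                    batch_size=32, num_workers=0)
```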