Running run_mlm_flax on TPU v4 pods
System Info
transformers 4.24.0
Who can help?
I am having problems scaling the run_mlm_flax scripts so that they run on TPU VM v4 pods (i.e. v4-16, v4-32, etc.). When running "out of the box", performance is exactly the same as on a v4-8. To me this indicates that I am feeding a lot of empty data. The maximum per_device_train_batch_size for 512-token sequences with RoBERTa is 62 in both cases, but since the output is identical, the job is obviously not scaling.
From trying to understand the code, it seems logical to multiply the batch size here by jax.process_count() (src example). However, this does not seem to be the right way to approach it.
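For context, a minimal sketch of that naive scaling idea, assuming only that the 62 above is the per-device size; the variable names are illustrative, not taken from the script:

```python
import jax

# Naive multi-host scaling idea from the paragraph above; illustrative,
# not a confirmed fix. On a pod, each host only sees its local devices,
# so the global batch would be the local batch times the process count.
per_device_train_batch_size = 62  # max size reported above
local_batch_size = per_device_train_batch_size * jax.local_device_count()
global_batch_size = local_batch_size * jax.process_count()

print(f"hosts={jax.process_count()}, local={local_batch_size}, global={global_batch_size}")
```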
Any ideas about how to approach this? Is the script tested on v4s?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
See explanation above.
Expected behavior
I expect the batch size to scale automatically with the pod size.
@sanchit-gandhi: I now have a working version that runs decently fast on the pods! I am down from 220 s/it to around 10 s/it on a v4-128.
I made the following change to the streaming code:
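The snippet itself is not quoted here. Purely as an illustration of the kind of per-host sharding such a change could involve (this is an assumption, not the actual diff), newer versions of the datasets library can split a streaming dataset across hosts with split_dataset_by_node:

```python
import jax
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node  # needs datasets >= 2.8

# Illustrative only: give each host its own 1/process_count slice of the
# stream, so no host downloads and tokenises data it will never train on.
dataset = load_dataset(
    "oscar", "unshuffled_deduplicated_no",  # placeholder streaming corpus
    split="train",
    streaming=True,
)
dataset = split_dataset_by_node(
    dataset, rank=jax.process_index(), world_size=jax.process_count()
)
```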
For some reason this is a lot faster, and fast enough to be "useful". I still do not think this is optimal, though: tokenising and grouping still slow down training considerably when you are using a streaming dataset.
You can use something like a torch DataLoader with num_workers > 0 with your streaming dataset. This way you load and collate the data in parallel with your forward and backward passes, as in the sketch below.
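A hedged sketch of that pattern; the corpus, sequence length, and batch size are placeholders, and it assumes a datasets version whose streaming datasets can be passed straight to torch's DataLoader (recent versions also split the underlying shards across workers so they do not duplicate data):

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# Placeholder streaming corpus; substitute your own.
raw = load_dataset("oscar", "unshuffled_deduplicated_no", split="train", streaming=True)

def tokenize(examples):
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized = raw.map(tokenize, batched=True, remove_columns=["id", "text"])

# return_tensors="np" yields NumPy batches, which JAX consumes directly.
collator = DataCollatorForLanguageModeling(
    tokenizer, mlm_probability=0.15, return_tensors="np"
)

# num_workers > 0 runs tokenisation and collation in background worker
# processes, overlapping data preparation with the TPU forward/backward pass.
loader = DataLoader(tokenized, batch_size=128, collate_fn=collator, num_workers=4)

for batch in loader:
    # batch is a dict of NumPy arrays, ready to be reshaped/sharded across
    # the local devices (e.g. with flax.training.common_utils.shard).
    ...
```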