
Running the run_mlm_flax script on TPU v4 pods


System Info

transformers 4.24.0

Who can help?


I am having problems scaling the run_mlm_flax script so that it runs on TPU VM v4 pods (i.e. v4-16, v4-32, etc.). When run “out of the box”, performance is exactly the same as on a v4-8, which suggests to me that I am feeding a lot of empty data. The maximum per_device_train_batch_size for 512-token sequences with RoBERTa is 62 in both cases, but since the output is identical, the job is obviously not scaling.

From reading the code, it would seem logical to multiply the batch size here by jax.process_count() (src example). However, this does not appear to be the right way to approach it.

Any ideas about how to approach this? Is the script tested on v4s?
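As background for the batch-size question, here is a minimal arithmetic sketch of how a global batch size usually decomposes in JAX's multi-host model. The topology numbers are illustrative, not taken from the issue: jax.device_count() already counts devices across all hosts, so each host should only load and feed its local shard rather than multiplying its batch by jax.process_count().

```python
# Hypothetical pod topology (illustrative numbers, not authoritative):
# a pod with 4 hosts, each host seeing 4 local devices.
num_hosts = 4                                # what jax.process_count() reports
local_devices = 4                            # what jax.local_device_count() reports
total_devices = num_hosts * local_devices    # what jax.device_count() reports

per_device_batch_size = 62

# The global batch covers every device in the pod ...
global_batch_size = per_device_batch_size * total_devices

# ... but each host only loads and feeds its own shard:
per_host_batch_size = per_device_batch_size * local_devices

assert global_batch_size == per_host_batch_size * num_hosts
```

If a host instead multiplies its per-host batch by jax.process_count(), the data is double-counted, which matches the symptom of a pod behaving no better than a single v4-8.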


Information

  • The official example scripts
  • My own modified scripts


Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)


Reproduction

See explanation above.

Expected behavior

I expect the batch size to scale automatically with the size of the pod.

Issue Analytics

  • State: open
  • Created: 10 months ago
  • Comments: 27 (27 by maintainers)

Top GitHub Comments

peregilk commented, Nov 28, 2022

@sanchit-gandhi: I now have a working version that runs decently fast on the pods! I am down from 220 s/it to around 10 s/it on a v4-128.

I made the following change to the streaming code:

# samples = {
#    k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
# }
samples["input_ids"] += tokenized_samples["input_ids"]
samples["attention_mask"] += tokenized_samples["attention_mask"]
samples["special_tokens_mask"] += tokenized_samples["special_tokens_mask"]

For some reason this is a lot faster, and fast enough to be “useful”. I still do not think this is optimal, though: tokenising and grouping still slow down training considerably when you are using a streaming dataset.
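The speedup is not mysterious: the original dict comprehension rebuilds each list with `+`, which copies everything accumulated so far on every call, making the loop quadratic, while `+=` extends the existing list in place, which is amortized linear. A toy demonstration (names and sizes are illustrative):

```python
import timeit

def rebuild(chunks):
    # Mimics `samples[k] = samples[k] + tokenized_samples[k]`:
    # allocates a brand-new list, copying the whole prefix, each iteration.
    acc = []
    for c in chunks:
        acc = acc + c
    return acc

def extend(chunks):
    # Mimics `samples[k] += tokenized_samples[k]`:
    # extends in place, no copy of the already-accumulated prefix.
    acc = []
    for c in chunks:
        acc += c
    return acc

chunks = [[0] * 100 for _ in range(2000)]
print("rebuild:", timeit.timeit(lambda: rebuild(chunks), number=1))
print("extend: ", timeit.timeit(lambda: extend(chunks), number=1))
```

Both produce identical results; only the asymptotics differ, which is why the change matters once the accumulated sample buffer grows large.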

lhoestq commented, Dec 20, 2022

You can wrap your streaming dataset in something like a torch DataLoader with num_workers > 0. That way the data is loaded and collated in parallel with your forward and backward passes.
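A minimal sketch of this pattern, using a toy IterableDataset as a stand-in for a real streaming dataset (the dataset class, sizes, and batch size here are all illustrative). One gotcha worth encoding: with num_workers > 0 each worker gets its own copy of an iterable dataset, so the stream must be sharded via get_worker_info() to avoid duplicate batches.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class ToyStream(IterableDataset):
    """Stand-in for a real streaming dataset (e.g. one loaded with streaming=True)."""

    def __iter__(self):
        info = get_worker_info()
        # Shard the stream across DataLoader workers so each sample
        # is yielded by exactly one worker.
        wid = info.id if info else 0
        nw = info.num_workers if info else 1
        for i in range(wid, 1000, nw):
            yield {"input_ids": torch.arange(i, i + 8)}

loader = DataLoader(
    ToyStream(),
    batch_size=64,
    num_workers=2,       # load and collate in parallel with training
    prefetch_factor=2,   # batches pre-buffered per worker
)

for batch in loader:
    pass  # replace with the shard + forward/backward step
```

The worker processes keep tokenization and collation off the main process, so the accelerator is not left idle waiting for the next batch.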
