TF2 DeBERTaV2 runs super slow on TPUs
System Info
Latest version of transformers, Colab TPU, TensorFlow 2
Who can help?
@kamalkraj @Rocketknight1 @BigBird01
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
It’s currently hard to share the code and access to the Google bucket, but I believe any TF2 DeBERTaV2 code running on TPUs will hit this issue.
Expected behavior
I’ve been trying to train a DeBERTa v3 model on GPUs and TPUs. I got it to work on multi-node, multi-GPU setups using NVIDIA’s Deep Learning Examples libraries (https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/). I basically used the training setup and loop from the BERT code, the dataset utilities from the ELECTRA code, and the model from Hugging Face Transformers, with some changes in order to share embeddings.
On 6x A40 45 GB GPUs I get around 1,370 sentences per second during training (lower than what NVIDIA gets for ELECTRA, but acceptable).
OK, now the problem: on TPU I get ~20 sentences per second.
I traced the issue back to the tf.gather call here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525
I ran TPU profiling and this is the output:
GatherV2 takes most of the time:
[profiler screenshots: GatherV2 dominating the trace, plus zoomed-in views of the fast ops]
Also, I’m not sure whether this is TPU-specific, since on GPUs training is ~30% slower compared to regular ELECTRA.
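For context: a common TPU-friendly workaround for a hot GatherV2 like this is to rewrite the gather along the sequence axis as a one-hot matmul, which XLA can lower onto the matrix units instead of a dynamic gather. A minimal NumPy sketch of the equivalence (function names are illustrative, not the actual transformers code):

```python
import numpy as np

def gather_take_along_axis(x, indices):
    # Baseline: gather rows of x along the sequence axis, analogous to
    # the tf.gather call in modeling_tf_deberta_v2.py that dominates
    # the TPU profile.
    return np.take_along_axis(x, indices[..., None], axis=1)

def onehot_take_along_axis(x, indices, seq_len):
    # TPU-friendly variant: build a one-hot matrix from the indices and
    # contract it with x. On TPU, XLA maps this matmul onto the MXU,
    # avoiding the slow dynamic GatherV2.
    one_hot = (indices[..., None] == np.arange(seq_len)).astype(x.dtype)  # (B, Q, S)
    return np.einsum("bqs,bsd->bqd", one_hot, x)

rng = np.random.default_rng(0)
B, S, D, Q = 2, 8, 4, 6
x = rng.standard_normal((B, S, D)).astype(np.float32)
idx = rng.integers(0, S, size=(B, Q))

a = gather_take_along_axis(x, idx)
b = onehot_take_along_axis(x, idx, S)
assert np.allclose(a, b)  # both variants produce identical (B, Q, D) output
```

The one-hot variant does more FLOPs than the gather, but they are dense matmul FLOPs, which is usually a net win on TPUs where dynamic gathers serialize badly.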
Issue Analytics
- Created: a year ago
- Comments: 34 (31 by maintainers)
Top GitHub Comments
Only for JAX on TPU, I’ll ask around and see if there is anyone who can help with TF!
For JAX BLOOM we couldn’t even compile the 176B parameter model with the naive implementation of concatenate_to_cache, let alone benchmark which operations consumed the bulk of the execution time! We swapped it for this more efficient implementation (with one-hot encodings etc.): https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L119 Coincidentally, we’ve just run the JAX profiler for this implementation and are going through the traceback with some of the Google JAX folks later today. Will report back on how performance fares!
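The same one-hot trick can also replace the dynamic scatter in a KV-cache update. A minimal NumPy sketch of the idea behind the linked one-hot concatenate_to_cache rewrite (names and shapes are illustrative, not the BLOOM code itself):

```python
import numpy as np

def onehot_cache_update(cache, new_kv, pos):
    # Write new_kv into row `pos` of the cache without dynamic-slice
    # indexing: build a one-hot mask over the cache length and blend.
    # Under jit/pmap this keeps the update as dense math XLA can fuse,
    # instead of a data-dependent scatter.
    seq_len = cache.shape[0]
    one_hot = (np.arange(seq_len) == pos).astype(cache.dtype)[:, None]  # (S, 1)
    return cache * (1.0 - one_hot) + one_hot * new_kv[None, :]

cache = np.zeros((4, 3), dtype=np.float32)          # (seq_len, head_dim)
new_kv = np.array([1.0, 2.0, 3.0], dtype=np.float32)
updated = onehot_cache_update(cache, new_kv, pos=2)
assert np.allclose(updated[2], new_kv)              # row 2 now holds new_kv
assert np.allclose(np.delete(updated, 2, axis=0), 0.0)  # other rows untouched
```

In real JAX code the dynamic alternative would be `jax.lax.dynamic_update_slice`; the one-hot blend trades extra multiplies for a shape- and index-static computation the compiler handles well at scale.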