Flax BERT finetuning notebook no longer works on TPUs
System Info
- Colab
- transformers version: 4.22.0.dev0
- Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.13
- Huggingface_hub version: 0.9.1
- PyTorch version (GPU?): 1.12.1+cu113 (False)
- Tensorflow version (GPU?): 2.8.2 (False)
- Flax version (CPU?/GPU?/TPU?): 0.6.0 (cpu)
- Jax version: 0.3.17
- JaxLib version: 0.3.15
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: yes
- Using TPU: yes
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
The problem arises with the official notebook examples/text_classification_flax.ipynb.
The official notebook has some trivial problems (e.g., gradient_transformation is never defined), which are fixed in this slightly modified version; a possible definition is sketched below.
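For reference, a minimal sketch of what the missing definition could look like, built with optax like the rest of the notebook; the schedule and optimizer hyperparameters here are placeholder assumptions, not the notebook's exact values:

```python
import optax

num_train_steps = 3 * 267  # epochs * steps per epoch, taken from the progress bars below (assumption)

# Placeholder linear decay schedule; the real notebook's schedule may differ.
learning_rate_fn = optax.linear_schedule(
    init_value=2e-5, end_value=0.0, transition_steps=num_train_steps
)

# This is the object passed as `tx` when creating the Flax TrainState.
gradient_transformation = optax.adamw(
    learning_rate=learning_rate_fn, weight_decay=0.01
)
```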
The notebook then gets stuck compiling at the training loop and exits with this error:
Epoch ...: 0%
0/3 [00:00<?, ?it/s]
Training...: 0%
0/267 [00:00<?, ?it/s]
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-33-e147f5aff5fe> in <module>
5 with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
----> 6 for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
7 state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
17 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
XlaRuntimeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py in permutation(key, x, axis, independent)
413 raise TypeError("x must be an integer or at least 1-dimensional")
414 r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
--> 415 return _shuffle(key, jnp.arange(r), axis)
416 if independent or np.ndim(x) == 1:
417 return _shuffle(key, x, axis)
XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.
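For context, the frame above comes from the notebook's data loader, which reshuffles example indices with jax.random.permutation at the start of every epoch. Roughly (paraphrased from the notebook; details may differ slightly):

```python
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

def glue_train_data_loader(rng, dataset, batch_size):
    """Yield sharded training batches in a freshly shuffled order."""
    steps_per_epoch = len(dataset) // batch_size
    # This is the jax.random.permutation call visible in the traceback.
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # drop the incomplete last batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield shard(batch)  # split the leading axis across the local TPU devices
```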
Expected behavior
The training is supposed to go smoothly. 😄
Top GitHub Comments
I am very happy that it worked, @NightMachinery! I think it makes sense here to have a “reference” colab that people can refer to - pinging @patil-suraj (for the fix I borrowed from the diffusers notebook) and @LysandreJik regarding the PR that you have suggested 😉 Thank you!
This works! I think the only difference with my previous code is supplying tpu_driver_20221011 to setup_tpu. Where is that documented? I suggest having a central Colab TPU guide in the HuggingFace docs that documents details like this, which are necessary to run any TPU notebook. Do you want me to send a PR for this specific notebook?