Flax BERT finetuning notebook no longer works on TPUs


System Info

  • Colab
  • transformers version: 4.22.0.dev0
  • Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.13
  • Huggingface_hub version: 0.9.1
  • PyTorch version (GPU?): 1.12.1+cu113 (False)
  • Tensorflow version (GPU?): 2.8.2 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.6.0 (cpu)
  • Jax version: 0.3.17
  • JaxLib version: 0.3.15
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: yes
  • Using TPU: yes

Who can help?

@patil-suraj @LysandreJik

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

The problem arises with the official notebook examples/text_classification_flax.ipynb.

The official notebook has some trivial problems (e.g., gradient_transformation is never defined), which are fixed in this slightly modified version.
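
For context, a minimal sketch of what a gradient_transformation definition typically looks like in the Flax text-classification example, assuming optax's AdamW; the checkpoint, step count, and hyperparameters below are placeholders rather than what the modified notebook actually uses:

import optax
from flax.training import train_state
from transformers import FlaxAutoModelForSequenceClassification

# Placeholder values; the real notebook derives num_train_steps from the dataset size.
num_train_steps = 801
model = FlaxAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

# Linear decay schedule feeding an AdamW gradient transformation.
learning_rate_fn = optax.linear_schedule(init_value=2e-5, end_value=0.0, transition_steps=num_train_steps)
gradient_transformation = optax.adamw(learning_rate=learning_rate_fn, weight_decay=0.01)

# The optimizer state is bundled with the parameters in a TrainState.
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=gradient_transformation)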

The notebook gets stuck compiling at the training loop and eventually exits with this error:

Epoch ...: 0%
0/3 [00:00<?, ?it/s]
Training...: 0%
0/267 [00:00<?, ?it/s]
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-33-e147f5aff5fe> in <module>
      5     with tqdm(total=len(train_dataset) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
----> 6       for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
      7         state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)

17 frames
UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

XlaRuntimeError                           Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/_src/random.py in permutation(key, x, axis, independent)
    413       raise TypeError("x must be an integer or at least 1-dimensional")
    414     r = core.concrete_or_error(int, x, 'argument x of jax.random.permutation()')
--> 415     return _shuffle(key, jnp.arange(r), axis)
    416   if independent or np.ndim(x) == 1:
    417     return _shuffle(key, x, axis)

XlaRuntimeError: INTERNAL: Compile failed to finish within 1 hour.

Expected behavior

The training is supposed to go smoothly. 😄

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 14 (5 by maintainers)

Top GitHub Comments

1 reaction
younesbelkada commented, Oct 13, 2022

I am very happy that it worked @NightMachinery! I think it makes sense here to have a "reference" Colab that people can refer to - pinging @patil-suraj (for the fix I borrowed from the diffusers notebook) and @LysandreJik regarding the PR you suggested 😉 Thank you!

1 reaction
NightMachinery commented, Oct 13, 2022

Hey @NightMachinery! Can you try these cells for installation? I think I gave you the wrong installation guidelines before:

#@title Set up JAX
#@markdown If you see an error, make sure you are using a TPU backend. Select `Runtime` in the menu above, then select the option "Change runtime type" and then select `TPU` under the `Hardware accelerator` setting.
!pip install --upgrade jax jaxlib 

import jax.tools.colab_tpu
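# setup_tpu's argument selects the TPU driver build that Colab loads; the dated
# 'tpu_driver_20221011' build is the one that worked here.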
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')

!pip install flax diffusers transformers ftfy
jax.devices()

I can confirm jax.devices() gave me:


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

This is based on the recent demo from diffusers, see the colab here: colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fast_jax.ipynb

This works! I think the only difference from my previous code is supplying tpu_driver_20221011 to setup_tpu. Where is that documented? I suggest adding a central Colab TPU guide to the Hugging Face docs that covers details like this, which are necessary to run any TPU notebook.

Do you want me to send a PR for this specific notebook?
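
For what it's worth, a sanity check such a guide could include (a sketch of mine, not taken from the notebook): after setup_tpu, confirm that all eight cores are visible and that a trivial pmap compiles, so a mismatched runtime fails within seconds instead of timing out an hour into training.

import jax
import jax.numpy as jnp

# A Colab v2-8 runtime should expose 8 TPU cores.
assert jax.device_count() == 8, jax.devices()

# Tiny pmap smoke test: a broken TPU runtime fails here quickly
# instead of hanging during the real training-step compilation.
x = jnp.ones((jax.device_count(), 4))
print(jax.pmap(lambda v: v * 2.0)(x))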
