TPU not initialized when running official `run_mlm_flax.py` example.
Environment info
- `transformers` version: 4.9.0.dev0
- Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29
- Python version: 3.8.5
- PyTorch version (GPU?): 1.9.0+cu102 (False)
- Tensorflow version (GPU?): 2.5.0 (False)
- Flax version (CPU?/GPU?/TPU?): 0.3.4 (tpu)
- Jax version: 0.2.16
- JaxLib version: 0.1.68
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
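(Not part of the original report: a version summary like the one above can be printed with a few standard imports; a minimal sketch, assuming each package exposes its usual `__version__` attribute:)

```python
# Minimal sketch to collect the version info above.
import platform

import flax
import jax
import jaxlib.version
import transformers

print(f"transformers version: {transformers.__version__}")
print(f"Platform: {platform.platform()}")
print(f"Flax version: {flax.__version__}")
print(f"Jax version: {jax.__version__}")
print(f"JaxLib version: {jaxlib.version.__version__}")
```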
Information
I am setting up a new TPU VM according to the Cloud TPU VM JAX quickstart and then following the installation steps described here: https://github.com/huggingface/transformers/tree/master/examples/research_projects/jax-projects#how-to-install-relevant-libraries to install `flax`, `jax`, `transformers`, and `datasets`.
Then, when running a simple example using the `run_mlm_flax.py` script, I'm encountering an error/warning:
```
INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: TPU Interpreter Host
```
=> I am now unsure whether the code actually runs on TPU or instead on CPU.
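(A quick check, not part of the original report: JAX can be asked directly which backend it ended up on. A minimal sketch using standard JAX calls:)

```python
import jax

# If the TPU initialized correctly, this lists TpuDevice entries
# (e.g. 8 of them on a v3-8 VM); a CPU fallback shows CpuDevice instead.
print(jax.devices())

# The platform string of the first device: "tpu" means the run is
# actually on TPU, "cpu" means it silently fell back.
print(jax.devices()[0].platform)
```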
To reproduce
The problem can be easily reproduced by:
- SSHing into a TPU VM, e.g. `patrick-test` (Flax, JAX, & Transformers should already be installed). If one goes into `patrick-test`, the libraries are already installed; on a newly created TPU VM, one can follow these steps to install the relevant libraries.
- Going to the home folder: `cd ~/`
- Creating a new directory: `mkdir test && cd test`
- Cloning a dummy repo into it: `GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/patrickvonplaten/norwegian-roberta-als`
- Linking the `run_mlm_flax.py` script: `ln -s $(realpath ~/transformers/examples/flax/language-modeling/run_mlm_flax.py) ./`
- Running the following command (which should show the above warning/error again):
```bash
./run_mlm_flax.py \
    --output_dir="norwegian-roberta-als" \
    --model_type="roberta" \
    --config_name="norwegian-roberta-als" \
    --tokenizer_name="norwegian-roberta-als" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_als" \
    --max_seq_length="128" \
    --per_device_train_batch_size="8" \
    --per_device_eval_batch_size="8" \
    --learning_rate="3e-4" \
    --overwrite_output_dir \
    --num_train_epochs="3"
```
=> You should see a console print that says:
```
[10:15:48] - INFO - absl - Starting the local TPU driver.
[10:15:48] - INFO - absl - Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
[10:15:48] - INFO - absl - Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: "cuda". Available platform names are: TPU Host Interpreter
```
Expected behavior
I think this warning / error should not be displayed and the TPU should be correctly configured.
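One way to make this expectation explicit (my suggestion, not code from `run_mlm_flax.py`) would be a fail-fast check at the top of the script:

```python
import jax

# Hypothetical guard: abort instead of silently training on CPU.
backend = jax.devices()[0].platform
if backend != "tpu":
    raise RuntimeError(f"Expected a TPU backend, but JAX selected: {backend}")
```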
Top GitHub Comments
@erensezener I think a lot has changed in the code here since this was written. I am linking to my internal notes above. I have run through that procedure several times and know it gets a working system up and running.
Just a wild guess: have you tried setting `export USE_TORCH=False`?
This solves the issue indeed! Thank you, you saved me many more hours of debugging 😃
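(For reference, a minimal sketch of the workaround in Python. `USE_TORCH` is an environment flag that `transformers` reads at import time, so the assumption here is that it must be set before the import:)

```python
import os

# Must be set before transformers is imported; otherwise the torch
# backend has already been detected and registered.
os.environ["USE_TORCH"] = "False"

import jax
import transformers  # noqa: F401  (imported for its side effects here)

# With the flag set, the TPU backend should be picked up as expected.
print(jax.devices()[0].platform)  # expected: "tpu"
```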