FLAX-T5: TPU not found on Colab
Hello,
I'm running the script run_t5_mlm_flax.py on Google Colab in TPU mode.
I have the following problem:
And also:
/usr/local/lib/python3.7/dist-packages/jax/__init__.py:27: UserWarning: cloud_tpu_init failed: ConnectionError(MaxRetryError("HTTPConnectionPool(host='metadata.google.internal', port=80): Max retries exceeded with url: /computeMetadata/v1/instance/attributes/agent-worker-number (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7f4ff0494790>: Failed to establish a new connection: [Errno 110] Connection timed out'))"))
This a JAX bug; please report an issue at https://github.com/google/jax/issues
_warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
The TPU is not found, and the code falls back to CPU. I'm using these libraries:
pip install datasets
pip install transformers
pip install flax
pip install optax
I also tried this configuration, which I found while searching:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
print(jax.local_devices())
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
unset LD_PRELOAD
USE_TORCH=0
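A minimal illustration of why running setup lines like these in a notebook cell does not carry over to a script launched from another cell (for example with `!python run_t5_mlm_flax.py`): each launch starts a fresh Python interpreter, so in-memory settings made in the notebook's process are not inherited. `BACKEND_CONFIGURED` below is a hypothetical stand-in for that kind of per-process state, not a real JAX setting.

```python
import subprocess
import sys

# Hypothetical stand-in for in-memory state that a call like setup_tpu()
# changes inside the notebook's own interpreter.
BACKEND_CONFIGURED = True


def child_sees_setting() -> str:
    """Start a fresh interpreter (as `!python script.py` does) and report
    whether it can see the parent's in-memory setting."""
    result = subprocess.run(
        [sys.executable, "-c", "print('BACKEND_CONFIGURED' in globals())"],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


print(BACKEND_CONFIGURED)    # True in this (notebook) process
print(child_sees_setting())  # False: the child process starts from scratch
```

Environment variables that are exported before launch are inherited by the child, but Python-level state such as module globals or JAX configuration is not.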
How can I do to use this code on Colab or to use a FLAX-T5 with TPU on Colab?
Thank you!
Found the issue. We need to call
in the script before importing anything JAX-related. Calling
setup_tpu()
in the Colab notebook and then launching the script won't work, because the notebook and the script run as two different processes. So adding these two lines to the script before any JAX/Flax import should fix this issue.

@patil-suraj Thanks, that's it!
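A sketch of what the top of run_t5_mlm_flax.py might look like with this fix, assuming the two lines are the `jax.tools.colab_tpu` calls quoted in the question. The `COLAB_TPU_ADDR` check and the helper name `setup_colab_tpu_if_available` are my own additions, not part of the original fix: they assume the legacy Colab TPU runtime sets that environment variable, and they let the same script still run on a machine without a Colab TPU. It also assumes a JAX version that still includes `jax.tools.colab_tpu`.

```python
import os


def setup_colab_tpu_if_available() -> bool:
    """Attach this process to the Colab TPU; return True if setup ran.

    Assumption: the legacy Colab TPU runtime sets COLAB_TPU_ADDR, and the
    installed JAX version still ships jax.tools.colab_tpu.
    """
    if "COLAB_TPU_ADDR" not in os.environ:
        return False  # not on a Colab TPU runtime: keep JAX's default backend
    import jax.tools.colab_tpu  # the two lines from the question, run in-script

    jax.tools.colab_tpu.setup_tpu()
    return True


# Must run first, before the script imports jax, flax, optax or any
# transformers Flax model, so this process is attached to the TPU.
TPU_SETUP_DONE = setup_colab_tpu_if_available()

# ... the script's existing imports (import jax, import flax, ...) follow here.
```

After this runs on a Colab TPU runtime, `print(jax.local_devices())` inside the script should list TPU devices rather than a single CPU device.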