ImportError: cannot import name 'xla_data_pb2' - JAX broken on Colab after current release
I use jax via trax, and today I got the error AttributeError: 'jaxlib.tpu_client_extension.TpuExecutable' object has no attribute 'ExecutePerReplica'.
Normally I install jax in Colab using:
!pip install --upgrade -q jax
import requests
import os

# Ask the Colab TPU host to switch to a specific TPU driver version (once per runtime).
if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
However, when I saw there had been a new release, I rolled back to the old one via !pip install --upgrade -q jax==0.1.39 jaxlib===0.1.39. After that I started getting ImportError: cannot import name 'xla_data_pb2' on from jax.config import config,
which I fixed temporarily by changing the URL to url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'. However, I then started getting an error on import jax instead.
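The fix above only swaps the last path segment of the driver request URL. A minimal sketch of how that URL is assembled, factored into a hypothetical helper `tpu_driver_url` (not part of jax or Colab), assuming COLAB_TPU_ADDR has the form "host:port"; the fallback address below is a made-up example:

```python
import os


def tpu_driver_url(colab_tpu_addr, driver_version):
    """Build the TPU driver version-request URL used in the Colab snippet.

    Only the host part of COLAB_TPU_ADDR is kept; the request goes to port 8475
    with the driver version as the final path segment.
    """
    host = colab_tpu_addr.split(":")[0]
    return f"http://{host}:8475/requestversion/{driver_version}"


# Hypothetical fallback address, used only when not running on a Colab TPU runtime.
addr = os.environ.get("COLAB_TPU_ADDR", "10.0.0.2:8470")
print(tpu_driver_url(addr, "tpu_driver0.1-dev20191206"))  # pinned dev build
print(tpu_driver_url(addr, "tpu_driver_nightly"))         # nightly driver
```

The original snippet then sends this URL with requests.post(url); that network call is left out here so the sketch runs anywhere.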
I tried upgrading pip (!pip install --upgrade pip), but the furthest I got was ImportError: cannot import name 'xla_client' on import jax.
Again, I can install the current jax, but it doesn't work with my version of trax, and I can no longer install the older version of jax at all.
Done. I released jax v0.1.60 to PyPI, which should fix the problem. Hope that helps!
We released jaxlib 0.1.41 just yesterday; perhaps the Colab backend is not picking up that version?
Can you try:
I get 0.1.59 and 0.1.41.
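To compare your installed versions against these, one option is to query the package metadata for jax and jaxlib separately. A minimal sketch, assuming Python 3.8+ (where importlib.metadata is in the standard library); the helper name `installed_versions` is hypothetical:

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


for name, version in installed_versions().items():
    print(f"{name}: {version or 'not installed'}")
```

Reading the distribution metadata avoids importing jax itself, which is useful when the import is exactly what is failing.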
Note that jax and jaxlib use separate version numbers, so it is not correct to use !pip install --upgrade -q jax==0.1.39 jaxlib===0.1.39.
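Because the two packages are versioned independently, each one needs its own pin. A minimal sketch with a hypothetical helper `pip_pin_command` that builds a Colab install line from a per-package mapping; the example versions are the pair reported above (jax 0.1.59, jaxlib 0.1.41), used only as an illustration, not as a verified compatible set:

```python
def pip_pin_command(pins):
    """Build a Colab pip command that pins each package to its own version."""
    # Use the ordinary "==" specifier for each package, independently.
    specs = " ".join(f"{name}=={version}" for name, version in pins.items())
    return f"!pip install --upgrade -q {specs}"


print(pip_pin_command({"jax": "0.1.59", "jaxlib": "0.1.41"}))
# → !pip install --upgrade -q jax==0.1.59 jaxlib==0.1.41
```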