AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
Code:
import jax
import jaxlib
import jaxlib.xla_extension as xe
jax.random.PRNGKey(0)
print(jax.devices())
print(xe.CpuDevice)
Output:
[GpuDevice(id=0, process_index=0)]
Traceback (most recent call last):
  File "~/jax/1_installation_test.py", line 8, in <module>
    print(xe.CpuDevice)
AttributeError: module 'jaxlib.xla_extension' has no attribute 'CpuDevice'
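To check for the missing attribute without triggering the traceback above, the module can be probed with hasattr first. This is only a minimal sketch: missing_attributes is a hypothetical helper name, not part of jax or jaxlib, and it returns None when the module cannot be imported at all.

```python
import importlib


def missing_attributes(module_name, names):
    """Return the names that the module does not expose.

    Hypothetical helper for illustration. Returns None if the module
    itself cannot be imported (for example, jaxlib is not installed).
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return [name for name in names if not hasattr(module, name)]


if __name__ == "__main__":
    # Probe the attribute from the report above instead of accessing it directly.
    report = missing_attributes("jaxlib.xla_extension", ["CpuDevice"])
    if report is None:
        print("jaxlib.xla_extension could not be imported")
    elif report:
        print("missing attributes:", report)
    else:
        print("all probed attributes are present")
```

On an installation that shows the error from this report, the script would be expected to list CpuDevice as missing rather than raise AttributeError.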
Environment details:
- Fedora 35 with CUDA 11.6, cuDNN 8.4
- jax 0.3.7
- jaxlib 0.3.7+cuda11.cudnn82
- Python 3.9
Issue Analytics
- Created a year ago
- Comments: 5 (2 by maintainers)
I ran into the same issue and solved it by upgrading chex:
pip install chex --upgrade

I just ran into the exact same error.
Environment details:
I just updated from jax/jaxlib 0.3.6 to 0.3.7 and took care to install the matching CUDA and cuDNN versions. I'm guessing it has something to do with v0.3.7.
Hoping it gets solved quickly.
Update: Just downgraded to jax 0.3.6 and jaxlib 0.3.5+cuda11.cudnn805 and it seems to work for now.
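Since the workarounds above depend on which versions of jax, jaxlib, and chex end up installed, it can help to print them after each upgrade or downgrade. The following is a rough sketch using only the standard library; installed_versions is a hypothetical helper name, and the package list is simply the three packages mentioned in this thread.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "chex")):
    """Map each package name to its installed version, or None if absent.

    Hypothetical helper for illustration; it reads package metadata only
    and does not import the packages themselves.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Because this reads metadata rather than importing the packages, it should still work when importing jax or chex itself fails.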