Remove `ensure_omnistaging` for newer versions of jax
It seems like the current version of mpi4jax isn’t compatible with the latest version of JAX (0.3.15). In particular, it seems like they’ve removed the `omnistaging_compatible` attribute of `jax.config`.
I’ve managed to track down the commit that removed this functionality. From what I can see, there’s no longer any point in having the `ensure_omnistaging` function.

Thanks for letting us know. This is already addressed in #151, which is blocked at the moment due to a regression in JAX 0.3.15. For now, JAX 0.3.15 is incompatible with mpi4jax. Please downgrade to 0.3.14.

For the OOM errors, you will probably need to prevent JAX from gobbling up all GPU memory (MPI needs some space to operate). For example, by setting
Regarding the warning, please see our docs.
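Until #151 lands, the advice in this thread is to stay on JAX 0.3.14. A minimal sketch of a guard that flags the incompatible release might look like the following. `jax_version_supported` is a hypothetical helper, not part of mpi4jax. Treating every release from 0.3.15 onward as unsupported is a conservative assumption, since the thread only states that 0.3.15 itself is broken.

```python
import re


def jax_version_supported(version: str, first_broken: tuple = (0, 3, 15)) -> bool:
    """Return True if `version` is older than the first JAX release assumed to break mpi4jax.

    Hypothetical helper: everything from `first_broken` onward is treated as
    unsupported, which is an assumption rather than a documented rule.
    """
    # Take up to three leading numeric components, e.g. "0.3.15.dev1" -> (0, 3, 15);
    # missing components default to 0, so "0.3" -> (0, 3, 0).
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) if p is not None else 0 for p in match.groups())
    return parts < tuple(first_broken)
```

In practice one would pass `jax.__version__` and emit a warning (or raise) when the helper returns `False`. Comparing integer tuples rather than raw strings avoids the classic trap where `"0.3.9" > "0.3.15"` lexicographically.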