Feature request: add jax[cudaversion] to pypi
Hi there,
I’m not quite sure how people generally install the correct CUDA version of jaxlib, but after some searching I found this notebook: https://www.kaggle.com/guillemkami/getting-started-with-nlp-using-jax
It explains pretty well how to install the correct version compatible with your CUDA installation.
Now I was thinking: this process could be made much simpler by publishing the different CUDA builds to PyPI. For instance, if my CUDA version is 10.1, I would install the right version of jaxlib by running:
$ pip3 install jaxlib[cuda101]
Or, in a CI/CD pipeline (like Travis or GitHub Actions), I might choose to install the CPU version:
$ pip3 install jaxlib[nocuda]
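To make the proposed naming concrete, here is a minimal sketch of how a user or CI script could turn a CUDA version string into the requirement to pass to pip. It assumes the hypothetical scheme from this request ("10.1" becomes `cuda101`, no CUDA becomes `nocuda`); the helper name `jaxlib_requirement` is made up, and these extras are not claimed to exist on PyPI.

```python
from typing import Optional


def jaxlib_requirement(cuda_version: Optional[str]) -> str:
    """Build a pip requirement string under the (hypothetical) proposed scheme.

    "10.1" -> "jaxlib[cuda101]", None (e.g. a CPU-only CI runner) -> "jaxlib[nocuda]".
    """
    if cuda_version is None:
        return "jaxlib[nocuda]"
    parts = cuda_version.strip().split(".")
    if len(parts) < 2:
        raise ValueError(f"expected MAJOR.MINOR, got {cuda_version!r}")
    # Only major and minor matter; a patch level such as "10.1.243" is ignored.
    major, minor = int(parts[0]), int(parts[1])
    return f"jaxlib[cuda{major}{minor}]"


if __name__ == "__main__":
    print(jaxlib_requirement("10.1"))  # jaxlib[cuda101]
    print(jaxlib_requirement(None))    # jaxlib[nocuda]
```

One practical note: some shells (zsh, for example) treat square brackets as glob characters, so on the command line the requirement is safer quoted, as in `pip3 install 'jaxlib[cuda101]'`.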
~I believe that the square brackets are treated as ordinary characters in pypi, which would mean that you can upload a new cuda version simply by registering a new package name with pypi.org.~
I had another look: the bracketed names are "extras", declared through the extras_require=... kwarg in setup.py.
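A detail worth keeping in mind is that an extra only adds dependencies on top of a package's normal ones; it does not change which wheel of the package itself gets installed. The toy model below sketches that behaviour, assuming a setup.py that passes `INSTALL_REQUIRES` and `EXTRAS_REQUIRE` to `setup(...)`. The distribution names (`jaxlib-cuda101`, `jaxlib-cpu`) are hypothetical, and `resolve` is a simplified illustration, not pip's real resolver.

```python
import re
from typing import Dict, List

# Hypothetical metadata, mirroring what a setup.py could declare via
#   setup(..., install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRE)
# The per-CUDA distribution names are made up for illustration.
INSTALL_REQUIRES: List[str] = ["numpy"]
EXTRAS_REQUIRE: Dict[str, List[str]] = {
    "cuda101": ["jaxlib-cuda101"],
    "cuda102": ["jaxlib-cuda102"],
    "nocuda": ["jaxlib-cpu"],
}

# "name" or "name[extra1,extra2]" (version specifiers are not handled here).
_REQ = re.compile(r"^\s*([A-Za-z0-9._-]+)\s*(?:\[([^\]]*)\])?\s*$")


def resolve(requirement: str) -> List[str]:
    """Toy extras resolution: base deps plus the deps of each requested extra."""
    match = _REQ.match(requirement)
    if match is None:
        raise ValueError(f"cannot parse requirement: {requirement!r}")
    extras = [e.strip() for e in (match.group(2) or "").split(",") if e.strip()]
    deps = list(INSTALL_REQUIRES)
    for extra in extras:
        # pip warns about an unknown extra and carries on; this toy just skips it.
        deps.extend(EXTRAS_REQUIRE.get(extra, []))
    return deps


if __name__ == "__main__":
    print(resolve("jaxlib[cuda101]"))  # ['numpy', 'jaxlib-cuda101']
    print(resolve("jaxlib"))           # ['numpy']
```

Under this model, `jaxlib[cuda101]` would only work if the CUDA-specific binaries lived in a separate distribution that the extra points at.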
Also, JAX is awesome, keep up the good work! -Kris
It isn’t quite what this issue suggested, but note that the GPU jaxlib installation should now be significantly simpler, see: https://github.com/google/jax#pip-installation
#3555 solves one of the problems (manylinux2010 compliance), but not all of them. The others are: