Proposed Dockerfile for Running JAX
Demand for JAX Dockerfile
There is some demand for a Dockerfile that runs jax / jaxlib (as opposed to the one currently in the repository that is used to build wheels for jax/jaxlib), see here, here, and probably here. I’ve separately pointed out that tensorflow-gpu images coincidentally work fine.
tf-gpu is somewhat of a moving target, and it is probably wise to decouple any jax Dockerfile from whatever tf-gpu does. As part of a package I'm writing on top of jax, I've put together a Dockerfile with "only what jaxlib needs", using the nvidia cuda images as a base (specifically the Ubuntu 20.04 cudnn devel image). I'm interested in having this upstreamed; in exchange, I'm committed to maintaining the Dockerfile for at least some time. First I present the Dockerfile, then I lay out some of the design choices (which are of course flexible to whatever the jax team wants).
Pre-Requisites for use
Use of the image presupposes that the user has installed the NVIDIA Container Toolkit (nvidia-docker) on their docker host. This has been tested on a RHEL8 docker host on consumer hardware.
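As a quick sanity check of the toolkit setup (using the same base image tag as the Dockerfile below), you can run `nvidia-smi` inside a plain CUDA container; this requires a working NVIDIA driver on the host and will fail if the toolkit is not wired up:

```shell
# Verify that the NVIDIA Container Toolkit exposes the host GPU to containers.
# The image tag matches the base image used in the Dockerfile below.
docker run --rm --gpus all nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 nvidia-smi
```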
Standalone Dockerfile
FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04
# declare the image name
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
# declare what jaxlib tag to use
# if a CI/CD system is expected to pass in these arguments
# the dockerfile should be modified accordingly
JAXLIB_VERSION=0.1.62
# install python3-pip
RUN apt-get update && apt-get install -y python3-pip
# install dependencies via pip
RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
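For reference, a build-and-smoke-test sequence might look like the following (the image name `jax-cuda-base` is just an example; the import check only confirms that the jaxlib wheel installed cleanly, and GPU access additionally requires the toolkit described above):

```shell
# Build the image from the Dockerfile above (run in the directory containing it).
docker build -t jax-cuda-base .

# Smoke test: importing jaxlib's XLA client confirms the wheel is importable.
docker run --rm --gpus all jax-cuda-base \
    python3 -c "from jaxlib import xla_client; print('jaxlib OK')"
```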
I have no idea how this can integrate into the CI/CD pipeline that the jax team has for the project (although I have looked at some of the files in build/), but I am happy to work with the team on integration.
Design Choices
The Dockerfile is intended to abstract away the trickiest part of getting jax+cuda working, which is jaxlib. As such it does not include jax itself (this simplifies the build matrix, since only every supported jaxlib version would need to be built, not every jaxlib-jax tuple). It is intended that users will pull this image as part of their own tooling, and append whatever jax version they need (either from source or pypi).
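As a sketch of the intended usage (the base image name `jax-cuda-base` and the pinned jax version are hypothetical examples, not anything published), a downstream image would simply layer jax on top of the base:

```dockerfile
# Hypothetical downstream image: the base tag and jax version are examples only.
FROM jax-cuda-base:latest

# jaxlib is already present in the base image; add whichever jax release you need,
# either from PyPI (as here) or from source.
RUN pip3 install jax==0.2.10
```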
I’ve also chosen to install the resulting jaxlib wheel globally (as opposed, for example, to creating a new “jax” user). This is similar to how the tf-gpu images work.
Legal Stuff
The use of the nvidia images is bound by the CUDA and cuDNN EULAs, and probably by Ubuntu's licensing terms too. I'm not a lawyer, but whatever tf-gpu does (if anything) to keep its images compliant will probably also work for the Dockerfile outlined above.
Top GitHub Comments
Thanks for this, it was the only way I was able to get the GPU backend running locally. For anyone reading now, I had to modify the container as such:
The NVIDIA JAX early access (EA) container is now released, register to get it 😃