Proposed Dockerfile for Running JAX

Demand for JAX Dockerfile

There is some demand for a Dockerfile that runs jax / jaxlib (as opposed to the one currently in the repository that is used to build wheels for jax/jaxlib), see here, here, and probably here. I’ve separately pointed out that tensorflow-gpu images coincidentally work fine.

tf-gpu is somewhat of a moving target, and it’s probably wise to decouple any jax Dockerfile from what tf-gpu does. As part of a package I’m writing on top of jax, I’ve put together a Dockerfile with “only what jaxlib needs”, using the nvidia cuda images as a base (specifically the Ubuntu 20.04 cudnn devel image). I’m interested in having this upstreamed; in exchange, I’m committed to maintaining the Dockerfile for at least some time. First, I present the Dockerfile. Then I’ll lay out some of the design choices (which are of course flexible to whatever the jax team wants).

Prerequisites for Use

Use of the image presupposes that the user has installed the NVIDIA Container Toolkit (the nvidia docker toolkit) on their docker host. This has been tested on a RHEL8 docker host on consumer hardware.

Standalone Dockerfile

FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04

# declare the image name
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
    # declare what jaxlib tag to use
    # if a CI/CD system is expected to pass in these arguments
    # the dockerfile should be modified accordingly
    JAXLIB_VERSION=0.1.62

# install python3-pip
RUN apt-get update && apt-get install -y python3-pip

# install dependencies via pip
RUN pip3 install numpy scipy six wheel jaxlib==${JAXLIB_VERSION}+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I have no idea how this could integrate into the CI/CD pipeline that the jax team has for the project (although I have looked at some of the files in build/), but I am happy to work with the team on integration.
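
One possible way to let a CI/CD system pass in the versions is to turn them into build arguments. The sketch below is only an illustration under stated assumptions, not part of the original proposal: CUDA_TAG and JAXLIB_CUDA_SUFFIX are hypothetical argument names, and the suffix has to be kept consistent with the chosen base image by whoever triggers the build.

```dockerfile
# Hypothetical build arguments; the defaults reproduce the Dockerfile above.
ARG CUDA_TAG=11.2.2-cudnn8-devel-ubuntu20.04
FROM nvidia/cuda:${CUDA_TAG}

# An ARG declared before FROM must be re-declared to be visible after it.
ARG CUDA_TAG
ARG JAXLIB_VERSION=0.1.62
# Local version suffix of the jaxlib wheel (assumed to match CUDA_TAG).
ARG JAXLIB_CUDA_SUFFIX=cuda112

# Keep the same environment variables as the standalone Dockerfile.
ENV IMG_NAME=${CUDA_TAG} \
    JAXLIB_VERSION=${JAXLIB_VERSION}

# install python3-pip
RUN apt-get update && apt-get install -y python3-pip

# install dependencies via pip
RUN pip3 install numpy scipy six wheel \
    "jaxlib==${JAXLIB_VERSION}+${JAXLIB_CUDA_SUFFIX}" \
    -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

A pipeline could then override a value with, for example, docker build --build-arg JAXLIB_VERSION=0.1.62 -t jaxlib-cuda:0.1.62 . (the image tag here is made up for the example).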

Design Choices

The Dockerfile is intended to abstract away the trickiest part of getting jax+cuda working, which is jaxlib. As such it does not include jax itself (this simplifies the build matrix, since only every supported jaxlib version would need to be built, not every jaxlib-jax tuple). It is intended that users will pull this image as part of their own tooling, and append whatever jax version they need (either from source or pypi).
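
To illustrate how a user might append jax on top of such an image, here is a minimal sketch. It assumes the jaxlib image has been built and tagged locally as jaxlib-cuda:0.1.62, which is a hypothetical name, and it deliberately leaves the jax version as a build argument because the version has to be one that is compatible with the installed jaxlib.

```dockerfile
# "jaxlib-cuda:0.1.62" is a hypothetical tag for the jaxlib-only image.
FROM jaxlib-cuda:0.1.62

# No default on purpose: pass a jax version compatible with the jaxlib
# in the base image via --build-arg. The build fails if it is left unset.
ARG JAX_VERSION

# Append jax from PyPI on top of the preinstalled jaxlib.
RUN pip3 install "jax==${JAX_VERSION}"

# Print the devices jax can see when the container starts.
CMD ["python3", "-c", "import jax; print(jax.devices())"]
```

Built with docker build --build-arg JAX_VERSION=<version> -t my-jax-app . and started with docker run --rm --gpus all my-jax-app (my-jax-app is again a made-up name), the container should list the GPU devices if the host toolkit is set up correctly.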

I’ve also chosen to install the resulting jaxlib wheel globally (as opposed, for example, to creating a new “jax” user). This is similar to how the tf-gpu images work.

Legal Stuff

The use of the nvidia images is bound by their CUDA and CUDNN EULAs, and probably by Ubuntu licensing terms too. I’m not a lawyer, but whatever tf-gpu does (if anything) to make their images compliant will probably work for the Dockerfile I’ve outlined above.

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Reactions: 11
  • Comments: 18 (1 by maintainers)

Top GitHub Comments

17 reactions
MattLangsenkamp commented, Feb 24, 2022

Thanks for this, it was the only way I was able to get the GPU backend running locally. For anyone reading now, I had to modify the container as follows:

FROM nvidia/cuda:11.6.0-devel-ubuntu20.04

# declare the image name
ENV IMG_NAME=11.6.0-devel-ubuntu20.04 \
    # declare what jaxlib tag to use
    # if a CI/CD system is expected to pass in these arguments
    # the dockerfile should be modified accordingly
    JAXLIB_VERSION=0.3.0

# install python3-pip
RUN apt-get update && apt-get install -y python3-pip

# install dependencies via pip
RUN pip3 install numpy scipy six wheel \
    "jaxlib==${JAXLIB_VERSION}+cuda11.cudnn82" "jax[cuda11_cudnn82]" \
    -f https://storage.googleapis.com/jax-releases/jax_releases.html

6 reactions
mjsML commented, Nov 2, 2022

The NVIDIA JAX early access (EA) container has now been released; register to get it 😃
