
Installing JAX on Arch


Installing JAX on Arch has been surprisingly difficult. I’ve been trying to (re)install it for several hours now, after accidentally updating CUDA. While I got it working last time (with CUDA 11.0), what I did then doesn’t work now.

Steps taken to install

I’m working with the AUR repository using yay:

yay cuda
yay cudnn

It reports the packages are successfully installed, with these versions:

cuda-11.2.1-2
cudnn-8.1.0.77-1

pip is already fully upgraded (21.0.11), so now for JAX:

$ sudo pip install --upgrade --force jax jaxlib==0.1.62+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Successfully installed jax-0.2.10 jaxlib-0.1.62+cuda112

Because JAX expects CUDA at /usr/local/cuda-XX.X, but Arch installs CUDA at /opt/cuda, I create a symbolic link:

sudo ln -s /opt/cuda /usr/local/cuda-11.2
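(As an aside, an alternative to the symlink, assuming this jaxlib build honors XLA's --xla_gpu_cuda_data_dir flag, would be to point XLA at Arch's CUDA location via an environment variable set before JAX initializes its GPU backend. I haven't tried this here:)

import os

# Hypothetical alternative to the /usr/local/cuda-11.2 symlink:
# tell XLA where Arch keeps CUDA. This has to happen before JAX
# first initializes its GPU backend, so set it before importing jax.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda"

import jax
print(jax.devices())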

Checking installation

Just as a sanity check, I see if JAX can access devices in a Python shell, which it can:

$ python
>>> import jax
>>> jax.devices()
[GpuDevice(id=0)]
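Device enumeration alone doesn't launch any GPU kernels, so as a slightly stronger sanity check (a minimal sketch, not part of my original session), a small matmul can be forced to run on the device:

import jax.numpy as jnp
from jax import random

# Run an actual kernel on the GPU and wait for it to finish.
x = random.normal(random.PRNGKey(0), (1000, 1000))
y = jnp.dot(x, x.T).block_until_ready()
print(y.shape, y.dtype)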

I then try to run the following example, from Convolutions in JAX:

from jax import numpy as jnp, random, lax

key = random.PRNGKey(1701)

kernel = jnp.zeros((3, 3, 3, 3))
kernel += jnp.array([[1, 1, 0],
                     [1, 0,-1],
                     [0,-1,1]])[:, :, jnp.newaxis, jnp.newaxis]

img = jnp.zeros((1, 200, 198, 3))

for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img = img.at[0, x:x+10, y:y+10, k].set(1)

out = lax.conv(jnp.transpose(img, [0,3,1,2]),
               jnp.transpose(kernel, [3,2,0,1]),
               (1, 1),
               'SAME')

I get the following errors.

2021-03-12 04:30:44.451633: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,3,200,198]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,3,200,198]{3,2,1,0} %parameter.1, f32[3,3,3,3]{3,2,1,0} %parameter.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n                      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))\n                      feature_group_count=1\n                      lhs_dilation=(1, 1)\n                      lhs_shape=(1, 3, 200, 198)\n                      padding=((1, 1), (1, 1))\n                      precision=None\n                      rhs_dilation=(1, 1)\n                      rhs_shape=(3, 3, 3, 3)\n                      window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.
2021-03-12 04:30:44.548471: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1881] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3294): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
  File "/home/kuhlig/Documents/Programming/convolutional-deconvolution/test.py", line 17, in <module>
    out = lax.conv(jnp.transpose(img, [0,3,1,2]),
  File "/usr/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1582, in conv
    return conv_general_dilated(lhs, rhs, window_strides, padding,
  File "/usr/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 600, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/usr/lib/python3.9/site-packages/jax/core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/usr/lib/python3.9/site-packages/jax/core.py", line 622, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/lib/python3.9/site-packages/jax/interpreters/xla.py", line 242, in apply_primitive
    return compiled_fun(*args)
  File "/usr/lib/python3.9/site-packages/jax/interpreters/xla.py", line 360, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3294): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'

I’ve read through everything I can find, and the only suggestion that turns up is that the jaxlib, CUDA, or cuDNN versions might mismatch. Unfortunately, they don’t seem to:

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Jan_28_19:32:09_PST_2021
Cuda compilation tools, release 11.2, V11.2.142
Build cuda_11.2.r11.2/compiler.29558016_0

$ whereis cudnn_version
cudnn_version: /usr/include/cudnn_version.h

$ cat /usr/include/cudnn_version.h
...
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 1
#define CUDNN_PATCHLEVEL 0
...
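The Python-side wheel versions can be checked the same way (a quick sketch; it should just echo what pip reported above):

import jax
from jaxlib import version as jaxlib_version

# Confirm which wheels Python actually imports.
print("jax    ", jax.__version__)              # expect 0.2.10
print("jaxlib ", jaxlib_version.__version__)   # expect 0.1.62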

What I did last time (or thought I did; something got it working, and I might be attributing it to the wrong thing) was to create symbolic links to the cudnn files in /usr/local/cuda-11.2/include and /usr/local/cuda-11.2/lib64, as follows:

sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.2/include
sudo ln -s /usr/lib64/libcudnn*.so /usr/lib64/libcudnn_static.a /usr/local/cuda-11.2/lib64
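One way to verify whether the loader then resolves the expected cuDNN (a ctypes sketch I'm including for illustration, not something from my original debugging):

import ctypes
import ctypes.util

# Ask the dynamic loader which libcudnn it would pick up.
name = ctypes.util.find_library("cudnn")   # e.g. 'libcudnn.so.8'
print("loader resolves cuDNN to:", name)

if name:
    lib = ctypes.CDLL(name)
    lib.cudnnGetVersion.restype = ctypes.c_size_t
    print("cudnnGetVersion():", lib.cudnnGetVersion())  # 8100 == 8.1.0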

The symlinks, unfortunately, do not seem to change anything, so I might just be barking up the wrong tree. Any help?

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 25 (9 by maintainers)

Top GitHub Comments

10 reactions
mattjj commented, Mar 12, 2021

Installing JAX on Arch has been surprisingly difficult

I thought difficult installations were the whole reason people used Arch! One time in grad school my X11 setup on Arch was broken for months due to a pacman update, so I just learned to work without a graphical interface (until I finally gave up and re-imaged the machine). At least I had the bleeding-edge version of wget though.

Maybe the lesson here is that if JAX installation is hard for an Arch user, it must be really hard… 😄

2 reactions
josephrocca commented, Sep 4, 2021

For those who get here via a Google search like I did and who, while skimming, missed the solution that hawkinsp gave: the fix is to prevent JAX from preallocating too much GPU memory by setting the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to something lower than 0.9:

$ export XLA_PYTHON_CLIENT_MEM_FRACTION=.7

Or:

>>> import os
>>> os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".7"

My GPU has 8GB of memory (about 6.5GB after the OS takes its share), which is still small by normal machine-learning standards, but this still seems more like a bug than a user error. I wonder if the situation here could be improved @hawkinsp? Even just a warning for GPUs with <10GB of memory, informing the user that they may need to set XLA_PYTHON_CLIENT_MEM_FRACTION to a lower value to avoid OOM errors?
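A minimal end-to-end sketch of that workaround (assuming only that the variable has to be set before JAX initializes its GPU backend; XLA_PYTHON_CLIENT_PREALLOCATE is a related knob from JAX's GPU memory notes and is not something mentioned in this thread):

import os

# Must be set before JAX allocates its GPU memory pool.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".7"
# Alternative: skip preallocation entirely (can increase fragmentation).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

from jax import numpy as jnp, lax

# Re-run a small NCHW convolution to confirm cuDNN now executes.
img = jnp.zeros((1, 3, 200, 198))
kernel = jnp.zeros((3, 3, 3, 3))
out = lax.conv(img, kernel, (1, 1), 'SAME')
print(out.shape)   # (1, 3, 200, 198)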

