Installing JAX on Arch
Installing JAX on Arch has been surprisingly difficult. I’ve been trying to (re)install it for several hours now, after accidentally updating CUDA. While I got it working last time (with CUDA 11.0), what I did then doesn’t work now.
Steps taken to install
I’m working with the AUR repository using yay
yay cuda
yay cudnn
It reports the packages are successfully installed, with these versions:
is already fully upgraded (21.0.11), so now for JAX:
$ sudo pip install --upgrade --force jax jaxlib==0.1.62+cuda112 -f
Successfully installed jax-0.2.10 jaxlib-0.1.62+cuda112
Because JAX expects CUDA at /usr/local/cuda-XX.X, but Arch installs CUDA at /opt/cuda, I create a symbolic link:
sudo ln -s /opt/cuda /usr/local/cuda-11.2
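To confirm that the link really resolves to the Arch CUDA directory, something like the following sketch may help. The helper name `describe_cuda_link` is hypothetical, and the two paths are simply the ones used in the command above.

```python
import os

def describe_cuda_link(link_path, expected_target):
    """Return (is_symlink, resolved_path, matches_expected) for link_path."""
    is_link = os.path.islink(link_path)
    resolved = os.path.realpath(link_path)  # follows every symlink in the path
    matches = resolved == os.path.realpath(expected_target)
    return is_link, resolved, matches

if __name__ == "__main__":
    # Paths taken from the ln -s command above; adjust them for your system.
    print(describe_cuda_link("/usr/local/cuda-11.2", "/opt/cuda"))
```

If the third value is False, the link either does not exist or points somewhere other than /opt/cuda.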
Checking installation
Just as a sanity check, I see if JAX can access devices in a Python shell, which it can:
$ python
>>> import jax
>>> jax.devices()
I then try to run the following example, from Convolutions in JAX:
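from jax import numpy as jnp, random, lax

key = random.PRNGKey(1701)

# 3x3 kernel, broadcast across input and output channels (HWIO layout)
kernel = jnp.zeros((3, 3, 3, 3))
kernel += jnp.array([[1, 1, 0],
                     [1, 0, -1],
                     [0, -1, 1]])[:, :, jnp.newaxis, jnp.newaxis]

# NHWC image with one 10x10 square per channel
img = jnp.zeros((1, 200, 198, 3))
for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img = img.at[0, x:x+10, y:y+10, k].set(1)

out = lax.conv(jnp.transpose(img, [0, 3, 1, 2]),     # lhs: NCHW image
               jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs: OIHW kernel
               (1, 1),                               # window strides
               'SAME')                               # padding, consistent with pad=1_1x1_1 in the log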
I get the following errors.
2021-03-12 04:30:44.451633: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,3,200,198]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,3,200,198]{3,2,1,0} %parameter.1, f32[3,3,3,3]{3,2,1,0} %parameter.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 3, 200, 198)\n padding=((1, 1), (1, 1))\n precision=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 3, 3)\n window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal.
2021-03-12 04:30:44.548471: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/ 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
File "/home/kuhlig/Documents/Programming/convolutional-deconvolution/", line 17, in <module>
out = lax.conv(jnp.transpose(img, [0,3,1,2]),
File "/usr/lib/python3.9/site-packages/jax/_src/lax/", line 1582, in conv
return conv_general_dilated(lhs, rhs, window_strides, padding,
File "/usr/lib/python3.9/site-packages/jax/_src/lax/", line 600, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/usr/lib/python3.9/site-packages/jax/", line 284, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/usr/lib/python3.9/site-packages/jax/", line 622, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/lib/python3.9/site-packages/jax/interpreters/", line 242, in apply_primitive
return compiled_fun(*args)
File "/usr/lib/python3.9/site-packages/jax/interpreters/", line 360, in _execute_compiled_primitive
out_bufs = compiled.execute(input_bufs)
in external/org_tensorflow/tensorflow/stream_executor/cuda/ 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
I’ve read through everything I can, and the only suggestion I can find is that the jaxlib, cuda or cudnn versions must mismatch. Unfortunately, they don’t seem to:
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Jan_28_19:32:09_PST_2021
Cuda compilation tools, release 11.2, V11.2.142
Build cuda_11.2.r11.2/compiler.29558016_0
$ whereis cudnn_version
cudnn_version: /usr/include/cudnn_version.h
$ cat /usr/include/cudnn_version.h
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 1
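The same comparison can be scripted rather than read off by eye. This is only a sketch: `parse_cudnn_version` and `jaxlib_cuda_tag` are hypothetical helper names, and it assumes the jaxlib wheel encodes the CUDA release as a `cudaXYZ` suffix, as in the `0.1.62+cuda112` version installed above.

```python
import re

def parse_cudnn_version(header_text):
    """Extract (major, minor, patch) from the text of cudnn_version.h.

    Fields that are not defined in the text come back as None.
    """
    def field(name):
        m = re.search(r"^\s*#define\s+CUDNN_" + name + r"\s+(\d+)",
                      header_text, re.MULTILINE)
        return int(m.group(1)) if m else None
    return field("MAJOR"), field("MINOR"), field("PATCHLEVEL")

def jaxlib_cuda_tag(cuda_release):
    """Turn an nvcc release string such as '11.2' into a tag such as 'cuda112'."""
    major, minor = cuda_release.split(".")[:2]
    return "cuda" + major + minor

# Sample text mirroring the two lines shown above; on a real system, read
# /usr/include/cudnn_version.h instead.
sample = "#define CUDNN_MAJOR 8\n#define CUDNN_MINOR 1\n"
print(parse_cudnn_version(sample))                  # (8, 1, None)
print(jaxlib_cuda_tag("11.2") in "0.1.62+cuda112")  # True
```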
What I did last time (or thought I did; something got it working, and I might be attributing it to the wrong thing) was create symbolic links to the cudnn files in /usr/local/cuda-11.2/include and /usr/local/cuda-11.2/lib64, as follows:
sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.2/include
sudo ln -s /usr/lib64/libcudnn*.so /usr/lib64/libcudnn_static.a /usr/local/cuda-11.2/lib64
This, unfortunately, does not seem to change anything, so I might just be barking up the wrong tree. Any help?
Issue Analytics
- Created 3 years ago
- Comments: 25 (9 by maintainers)
I thought difficult installations were the whole reason people used Arch! One time in grad school my X11 setup on Arch was broken for months due to a pacman update, so I just learned to work without a graphical interface (until I finally gave up and re-imaged the machine). At least I had the bleeding-edge version of wget though.
Maybe the lesson here is that if JAX installation is hard for an Arch user, it must be really hard… 😄
For those who get here via a Google search like I did and who, while skimming, missed the solution that hawkinsp gave: prevent JAX from preallocating too much memory by setting the environment variable to something lower than 0.9.
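The variable name did not survive in the text above. As a minimal sketch, assuming the variable meant is `XLA_PYTHON_CLIENT_MEM_FRACTION` (the one JAX's GPU memory allocation documentation describes for the preallocated fraction), it can be set from Python before JAX is imported. The helper name `limit_jax_gpu_memory` and the value 0.5 are illustrative only.

```python
import os

def limit_jax_gpu_memory(fraction):
    """Ask XLA to preallocate only `fraction` of GPU memory.

    Assumption: XLA_PYTHON_CLIENT_MEM_FRACTION is the variable the comment
    refers to. It must be set before jax is imported to take effect.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)

limit_jax_gpu_memory(0.5)  # arbitrary illustrative value below 0.9
# import jax  # import JAX only after the variable is set
```

The same documentation also describes `XLA_PYTHON_CLIENT_PREALLOCATE=false`, which disables preallocation entirely; which of the two suits a given GPU is a judgment call.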
My GPU has 8GB of memory (about 6.5GB after the OS takes its share), which is still small by normal machine-learning standards, but this still seems more like a bug than a user error. I wonder if the situation here could be improved @hawkinsp? Even just a warning for GPUs that have <10GB of memory, informing the user that they may need to set it to a lower amount to prevent OOM errors?