Installing JAX on Arch
Installing JAX on Arch has been surprisingly difficult. I’ve been trying to (re)install it for several hours now, after accidentally updating CUDA. While I got it working last time (with CUDA 11.0), what I did then doesn’t work now.
Steps taken to install
I’m working with the AUR repository using yay:
yay cuda
yay cudnn
It reports the packages are successfully installed, with these versions:
cuda-11.2.1-2
cudnn-8.1.0.77-1
pip is already fully upgraded (21.0.1), so now for JAX:
$ sudo pip install --upgrade --force jax jaxlib==0.1.62+cuda112 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Successfully installed jax-0.2.10 jaxlib-0.1.62+cuda112
Because JAX expects CUDA at /usr/local/cuda-XX.X, but Arch installs CUDA at /opt/cuda, I create a symbolic link:
sudo ln -s /opt/cuda /usr/local/cuda-11.2
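An alternative to the symlink (untested here, and assuming this jaxlib build honors XLA's --xla_gpu_cuda_data_dir flag) would be to point XLA at /opt/cuda directly before importing jax:
# Sketch: tell XLA where Arch keeps CUDA instead of creating /usr/local/cuda-11.2.
# Assumes the --xla_gpu_cuda_data_dir flag is honored by this jaxlib build.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda"
import jax  # must be imported after XLA_FLAGS is set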
Checking installation
Just as a sanity check, I see if JAX can access devices in a Python shell, which it can:
$ python
>>> import jax
>>> jax.devices()
[GpuDevice(id=0)]
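For a slightly stronger check (a minimal sketch; any operation that forces execution on the GPU will do), a small computation can be run and waited on:
from jax import numpy as jnp, random
key = random.PRNGKey(0)
x = random.normal(key, (1000, 1000))
y = jnp.dot(x, x).block_until_ready()  # forces the matmul to actually run on the device
print(y.shape)  # (1000, 1000)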
I then try to run the following example, from Convolutions in JAX:
from jax import numpy as jnp, random, lax

key = random.PRNGKey(1701)

# 3x3 edge-detection-style kernel, broadcast over the input/output channel dims
kernel = jnp.zeros((3, 3, 3, 3))
kernel += jnp.array([[1, 1, 0],
                     [1, 0, -1],
                     [0, -1, 1]])[:, :, jnp.newaxis, jnp.newaxis]

# NHWC image with three small squares, one per channel
img = jnp.zeros((1, 200, 198, 3))
for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img = img.at[0, x:x+10, y:y+10, k].set(1)

out = lax.conv(jnp.transpose(img, [0, 3, 1, 2]),     # lhs: NHWC -> NCHW
               jnp.transpose(kernel, [3, 2, 0, 1]),  # rhs: HWIO -> OIHW
               (1, 1),                               # window strides
               'SAME')                               # padding mode
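For reference, a successful run should end with an output of shape (1, 3, 200, 198), which follows from the SAME padding and unit strides:
print(out.shape)  # expected: (1, 3, 200, 198)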
Instead, I get the following errors:
2021-03-12 04:30:44.451633: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:780] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,3,200,198]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,3,200,198]{3,2,1,0} %parameter.1, f32[3,3,3,3]{3,2,1,0} %parameter.2), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(1, 3, 200, 198)\n padding=((1, 1), (1, 1))\n precision=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 3, 3)\n window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal.
2021-03-12 04:30:44.548471: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1881] Execution of replica 0 failed: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3294): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
Traceback (most recent call last):
  File "/home/kuhlig/Documents/Programming/convolutional-deconvolution/test.py", line 17, in <module>
    out = lax.conv(jnp.transpose(img, [0,3,1,2]),
  File "/usr/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1582, in conv
    return conv_general_dilated(lhs, rhs, window_strides, padding,
  File "/usr/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 600, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/usr/lib/python3.9/site-packages/jax/core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/usr/lib/python3.9/site-packages/jax/core.py", line 622, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/lib/python3.9/site-packages/jax/interpreters/xla.py", line 242, in apply_primitive
    return compiled_fun(*args)
  File "/usr/lib/python3.9/site-packages/jax/interpreters/xla.py", line 360, in _execute_compiled_primitive
    out_bufs = compiled.execute(input_bufs)
RuntimeError: Unknown: CUDNN_STATUS_EXECUTION_FAILED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(3294): 'cudnnConvolutionForward( cudnn.handle(), alpha, input_nd.handle(), input_data.opaque(), filter_nd.handle(), filter_data.opaque(), conv.handle(), ToConvForwardAlgo(algorithm_desc), scratch_memory.opaque(), scratch_memory.size(), beta, output_nd.handle(), output_data.opaque())'
I’ve read through everything I can find, and the only suggestion I’ve come across is that the jaxlib, cuda, and cudnn versions must be mismatched. Unfortunately, they don’t seem to be:
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Jan_28_19:32:09_PST_2021
Cuda compilation tools, release 11.2, V11.2.142
Build cuda_11.2.r11.2/compiler.29558016_0
$ whereis cudnn_version
cudnn_version: /usr/include/cudnn_version.h
$ cat /usr/include/cudnn_version.h
...
#define CUDNN_MAJOR 8
#define CUDNN_MINOR 1
#define CUDNN_PATCHLEVEL 0
...
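As one more check that bypasses the header files, the runtime cuDNN version can also be queried directly from Python (a sketch, assuming libcudnn.so.8 is on the dynamic loader path):
import ctypes
# cudnnGetVersion() encodes the version as MAJOR*1000 + MINOR*100 + PATCH, e.g. 8100 for 8.1.0
libcudnn = ctypes.CDLL("libcudnn.so.8")
print(libcudnn.cudnnGetVersion())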
What I did last time (or thought I did; something got it working, and I might be attributing it to the wrong thing) was create symbolic links to the cudnn files in /usr/local/cuda-11.2/include and /usr/local/cuda-11.2/lib64, as follows:
sudo ln -s /usr/include/cudnn*.h /usr/local/cuda-11.2/include
sudo ln -s /usr/lib64/libcudnn*.so /usr/lib64/libcudnn_static.a /usr/local/cuda-11.2/lib64
This, unfortunately, does not seem to change anything, so I might just be barking up the wrong tree. Any help?
Top GitHub Comments
I thought difficult installations were the whole reason people used Arch! One time in grad school my X11 setup on Arch was broken for months due to a pacman update, so I just learned to work without a graphical interface (until I finally gave up and re-imaged the machine). At least I had the bleeding-edge version of wget though.
Maybe the lesson here is that if JAX installation is hard for an Arch user, it must be really hard… 😄
For those who get here via a Google search like I did, and who, while skimming, missed the solution that hawkinsp gave: the fix is to prevent JAX from preallocating too much GPU memory by setting the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable to something lower than 0.9.
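For example (a minimal sketch; 0.5 is just an illustrative value, and disabling preallocation entirely is another documented option):
import os
# Must be set before jax initializes its GPU backend, i.e. before importing jax.
# Equivalently, export XLA_PYTHON_CLIENT_MEM_FRACTION in the shell before launching Python.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # illustrative value below the 0.9 default
# Or disable preallocation altogether:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax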
My GPU has 8GB of memory (about 6.5GB after the OS takes its share), which is admittedly small by normal machine-learning standards, but this still seems more like a bug than a user error. I wonder if the situation here could be improved, @hawkinsp? Even just a warning for GPUs that have <10GB of memory, informing the user that they may need to set XLA_PYTHON_CLIENT_MEM_FRACTION to a lower value to prevent OOM errors?