Problems installing JAX on a GCP deep learning VM with GPU
I have created a GCP VM with an A100 GPU and this default image: c0-deeplearning-common-cu113-v20211219-debian-10. This is CUDA 11.3, cuDNN 8.2, Debian 10 and Python 3.7. I installed JAX thus:
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
Inside Python 3.7 I type ‘import jax’ but I get this error:
version `GLIBCXX_3.4.26' not found
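One way to confirm what this error is complaining about is to list the GLIBCXX symbol versions that the system libstdc++ actually provides. The sketch below does roughly what `strings libstdc++.so.6 | grep GLIBCXX` would do. The library path is an assumption (the usual location on 64-bit Debian), and the helper names are hypothetical, not part of JAX.

```python
import re
from pathlib import Path

# Assumed location of the system C++ runtime on 64-bit Debian; adjust if yours differs.
LIBSTDCXX_PATH = Path("/usr/lib/x86_64-linux-gnu/libstdc++.so.6")


def glibcxx_versions(data: bytes) -> list:
    """Return the sorted GLIBCXX_x.y[.z] versions embedded in a binary blob."""
    found = set(re.findall(rb"GLIBCXX_(\d+(?:\.\d+)+)", data))
    # Compare numerically, so that 3.4.25 sorts after 3.4.9.
    return sorted(tuple(int(p) for p in v.decode().split(".")) for v in found)


def has_glibcxx(data: bytes, wanted: str) -> bool:
    """True if the blob advertises the requested version, e.g. '3.4.26'."""
    return tuple(int(p) for p in wanted.split(".")) in glibcxx_versions(data)


if __name__ == "__main__":
    if LIBSTDCXX_PATH.exists():
        blob = LIBSTDCXX_PATH.read_bytes()
        versions = glibcxx_versions(blob)
        if versions:
            print("newest GLIBCXX:", ".".join(map(str, versions[-1])))
        print("has GLIBCXX_3.4.26:", has_glibcxx(blob, "3.4.26"))
    else:
        print("libstdc++ not found at", LIBSTDCXX_PATH)
```

If the last line prints False, the libstdc++ at that path is older than what the installed jaxlib wheel expects. Which libstdc++ the Python process actually loads can differ from the system one, so treat the result as a hint rather than a diagnosis.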
According to this issue, I can solve this by first creating a venv and then installing:
python -m venv env
source env/bin/activate
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
This partly works, in that I can now ‘import jax’ and run it. However, it fails when I use ‘jax.lax.scan’. In particular, the code snippet below gives this error:
2022-01-17 19:46:23.259785: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2086] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain.
Here is the code:
import jax
import jax.numpy as jnp
# sample from a Markov chain
init_dist = jnp.array([0.8, 0.2])
trans_mat = jnp.array([[0.9, 0.1], [0.5, 0.5]])
rng_key = jax.random.PRNGKey(0)
from jax.scipy.special import logit
seq_len = 15
initial_state = jax.random.categorical(rng_key, logits=logit(init_dist), shape=(1,))
def draw_state(prev_state, key):
    logits = logit(trans_mat[:, prev_state])
    state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
    return state, state
rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)
keys = jax.random.split(rng_state, seq_len - 1)
final_state, states = jax.lax.scan(draw_state, initial_state, keys)
print(states)
The CUDA driver version incompatibility problem should be fixed in the next jaxlib release. JAX will automatically fall back to not using parallel compilation if the NVIDIA driver is too old.
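To see whether a given machine might be affected, it can help to read the installed NVIDIA driver version first. This is a minimal sketch, assuming `nvidia-smi` is on the PATH. The helper names are hypothetical, and the version to compare against depends on the CUDA release your jaxlib wheel was built with, so no threshold is hard-coded here.

```python
import shutil
import subprocess


def parse_driver_version(text: str) -> tuple:
    """Turn nvidia-smi's driver_version output, e.g. '470.57.02', into a tuple of ints."""
    # With several GPUs the same version is printed once per GPU; use the first line.
    first_line = text.strip().splitlines()[0]
    return tuple(int(part) for part in first_line.strip().split("."))


def installed_driver_version():
    """Return the driver version as a tuple, or None if it cannot be determined."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0 or not result.stdout.strip():
        return None
    return parse_driver_version(result.stdout)


if __name__ == "__main__":
    print("NVIDIA driver version:", installed_driver_version())
```

Comparing the returned tuple against the minimum driver listed in NVIDIA's CUDA compatibility table for your toolkit version indicates whether a driver upgrade is worth trying.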
Unfortunately we had to revert the workaround for "version GLIBCXX_3.4.26 not found" because the workaround was to import scipy ourselves, but that turns out to be too slow to do every time jax is imported. If you still see that problem, I recommend one of the workarounds above. Note that jax is also available via conda-forge (https://github.com/google/jax#conda-installation) and using the conda installation of JAX will not have this issue.
@sayakpaul the issue is you have CUDA 11.0 installed. JAX doesn’t support CUDA 11.0. Install a newer CUDA.
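A quick way to check which CUDA toolkit a VM has is to parse the output of `nvcc --version`. The sketch below assumes `nvcc` is on the PATH; the function names are hypothetical, and the `MIN_CUDA` value is an assumption chosen only because CUDA 11.0 is reported as unsupported above, so check the JAX installation notes for the real minimum for your jaxlib version.

```python
import re
import shutil
import subprocess

# Assumption: oldest CUDA release accepted by your jaxlib build. Verify against the JAX docs.
MIN_CUDA = (11, 1)


def parse_nvcc_release(text: str):
    """Extract (major, minor) from `nvcc --version` output, or None if not present."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def cuda_is_new_enough(version, minimum=MIN_CUDA) -> bool:
    """True if a parsed (major, minor) version is at least the assumed minimum."""
    return version is not None and version >= minimum


if __name__ == "__main__":
    if shutil.which("nvcc") is None:
        print("nvcc not found on PATH")
    else:
        out = subprocess.run(
            ["nvcc", "--version"], capture_output=True, text=True, check=False
        ).stdout
        version = parse_nvcc_release(out)
        print("CUDA toolkit:", version, "| new enough:", cuda_is_new_enough(version))
```

Note that this reports the toolkit found via `nvcc`, which may not be the only CUDA installation on the machine.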
On Tue, Jan 18, 2022 at 6:12 AM Peter Hawkins @.***> wrote:
How can we do that? Given that this is what every GCP user who requests a GPU VM is going to experience, I think the instructions should be clear; otherwise it will just drive users towards other cloud providers.