problems installing JAX on a GCP deep learning VM with GPU

I created a GCP VM with an A100 GPU using this default image: c0-deeplearning-common-cu113-v20211219-debian-10. It provides CUDA 11.3, cuDNN 8.2, Debian 10, and Python 3.7. I installed JAX as follows:

pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  

Inside Python 3.7 I type ‘import jax’, but I get this error:

 version `GLIBCXX_3.4.26' not found

According to this issue, I can solve this by first creating a venv and then installing:

python -m venv env
source env/bin/activate
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html  

This partly works, in that I can now ‘import jax’ and run it. However, it fails when I use ‘jax.lax.scan’. In particular, the code snippet below gives this error:

2022-01-17 19:46:23.259785: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2086] Execution of replica 0 failed: INTERNAL: CustomCall failed: jaxlib/cuda_prng_kernels.cc:30: operation cudaGetLastError() failed: the provided PTX was compiled with an unsupported toolchain.

Here is the code:

import jax
import jax.numpy as jnp

# sample from a Markov chain
init_dist = jnp.array([0.8, 0.2])
trans_mat = jnp.array([[0.9, 0.1], [0.5, 0.5]])
rng_key = jax.random.PRNGKey(0)
from jax.scipy.special import logit
seq_len = 15

initial_state = jax.random.categorical(rng_key, logits=logit(init_dist), shape=(1,))

def draw_state(prev_state, key):
    logits = logit(trans_mat[:, prev_state])
    state = jax.random.categorical(key, logits=logits.flatten(), shape=(1,))
    return state, state

rng_key, rng_state, rng_obs = jax.random.split(rng_key, 3)
keys = jax.random.split(rng_state, seq_len - 1)

final_state, states = jax.lax.scan(draw_state, initial_state, keys)

print(states)

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Reactions: 3
  • Comments: 17 (9 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Oct 27, 2022

The CUDA driver version incompatibility problem should be fixed in the next jaxlib release. JAX will automatically fall back to not using parallel compilation if the NVIDIA driver is too old.

Unfortunately we had to revert the workaround for version GLIBCXX_3.4.26 not found because the workaround was to import scipy ourselves, but that turns out to be too slow to do every time jax is imported. If you still see that problem, I recommend one of the workarounds above. Note that jax is also available via conda-forge (https://github.com/google/jax#conda-installation) and using the conda installation of JAX will not have this issue.

@sayakpaul the issue is you have CUDA 11.0 installed. JAX doesn’t support CUDA 11.0. Install a newer CUDA.
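For the GLIBCXX_3.4.26 error discussed in this comment, it can help to see which GLIBCXX version tags a given libstdc++ actually exports before choosing a workaround. The following is a minimal sketch, not part of JAX: the helper names and the library path are assumptions (the path is a typical Debian location and may differ on your image or inside a conda environment). It scans the raw bytes of the shared library for version tags.

```python
import re
from pathlib import Path

# Assumed location of the system C++ runtime on Debian; adjust as needed.
DEFAULT_LIBSTDCXX = Path("/usr/lib/x86_64-linux-gnu/libstdc++.so.6")

# Matches tags such as GLIBCXX_3.4 or GLIBCXX_3.4.26 embedded in the binary.
_GLIBCXX_RE = re.compile(rb"GLIBCXX_(\d+(?:\.\d+)+)")


def glibcxx_versions(blob):
    """Return sorted, de-duplicated GLIBCXX version tuples found in raw bytes."""
    found = set()
    for match in _GLIBCXX_RE.finditer(blob):
        parts = match.group(1).decode("ascii").split(".")
        found.add(tuple(int(p) for p in parts))
    return sorted(found)


def provides(blob, wanted):
    """True if the bytes contain the given version tag, e.g. '3.4.26'."""
    target = tuple(int(p) for p in wanted.split("."))
    return target in glibcxx_versions(blob)


if __name__ == "__main__":
    if DEFAULT_LIBSTDCXX.exists():
        data = DEFAULT_LIBSTDCXX.read_bytes()
        versions = glibcxx_versions(data)
        if versions:
            print("newest GLIBCXX tag:", ".".join(map(str, versions[-1])))
        print("has 3.4.26:", provides(data, "3.4.26"))
    else:
        print("libstdc++ not found at", DEFAULT_LIBSTDCXX)
```

If the tag is missing from the library that Python loads, the import fails as described. Running the same check against the libstdc++ shipped in another environment shows whether that environment would avoid the error.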

1 reaction
murphyk commented, Jan 18, 2022

On Tue, Jan 18, 2022 at 6:12 AM Peter Hawkins wrote:

The GLIBCXX version issue was for the last user a scipy issue. There’s not much we can do other than to drop our scipy dependency or make it optional. That may be possible, we’d have to look into it.

“the provided PTX was compiled with an unsupported toolchain” means that the driver version on the VM is too old for the JAX binary. We may need to build with an older CUDA release. Another option would be for JAX to warn if the CUDA release is too old and hint that the user needs to upgrade their driver.

How can we do that? Given that this is what every GCP user who requests a GPU VM is going to experience, I think the instructions should be clear; otherwise it will just drive users towards other cloud providers.
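The quoted explanation ties the PTX error to a driver that is older than the JAX binary expects. A rough way to check this on a VM is to compare the installed driver version with a required minimum. The sketch below is only an illustration under stated assumptions: the helper names are hypothetical, the minimum version is passed in by the user (look it up in NVIDIA's CUDA release notes for the CUDA version your jaxlib wheel targets), and it relies on nvidia-smi being on the PATH.

```python
import subprocess
import sys


def parse_version(text):
    """Turn a string like '465.19.01' into (465, 19, 1) for numeric comparison."""
    return tuple(int(part) for part in text.strip().split("."))


def driver_too_old(installed, required):
    """True if the installed driver version is strictly older than the required one."""
    return parse_version(installed) < parse_version(required)


def installed_driver_version():
    """Ask nvidia-smi for the driver version reported for the first GPU."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout.splitlines()[0].strip()


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("usage: check_driver.py <minimum driver version>")
    else:
        try:
            current = installed_driver_version()
        except (FileNotFoundError, subprocess.CalledProcessError) as err:
            print("could not query nvidia-smi:", err)
        else:
            verdict = "too old" if driver_too_old(current, sys.argv[1]) else "ok"
            print("driver", current, "is", verdict, "for minimum", sys.argv[1])
```

Versions are compared as integer tuples rather than strings, so that a component such as '01' is not ordered lexically.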


