seg fault happens when running CPU code using GPU-supported jaxlib
The following repro script, which is taken from #1239,
import jax
from jax.config import config; config.update('jax_platform_name', 'cpu')
import jax.numpy as np
from jax import random, lax, jit

def welford_covariance():
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn

def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn

wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update) # uncomment this will make it fast
z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
import time
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)
causes
[mutex.cc : 419] RAW: Lock blocking 0x7fbae0006640 @ 0x7fbbc9577eaf 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
[mutex.cc : 419] RAW: Unlock 0x7fbae0006640 @ 0x7fbbc6a80ae2 0x7fbbc9578088 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
using https://storage.googleapis.com/jax-releases/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl.
This issue only started happening after the recent refactoring of JAX (related to no more tuples). The code runs fine if we change config.update('jax_platform_name', 'cpu')
to config.update('jax_platform_name', 'gpu')
or if we use jaxlib from PyPI.
Issue Analytics
- State:
- Created 4 years ago
- Comments:11 (11 by maintainers)
I think I spotted it! I had been thinking this was a jaxlib issue, but I tried rolling back jax a bit and the segfault went away. That suggested xla.py was accidentally mixing memory pointers on one device with those on another, or something like that, which led me to a bug that is now fixed in 434d175 (another of mine, I’m afraid, from the big rewrite!).
I’m not sure why we weren’t able to repro this internally. I’m working on a CI test we can use for this to avoid a regression, but right now I’m able to verify that this fixes the issue on a GPU cloud VM.
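For anyone who wants a quick local check in the meantime, here is a minimal sketch of a CPU smoke test. It is not the CI test mentioned above. It assumes a recent JAX release, where `lax.cond` takes `(pred, true_fun, false_fun, *operands)` rather than the five-argument form used in the repro script, and the names `step` and `run_smoke_test` are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import jit, lax

# Pin computation to the CPU before anything runs, as the repro script does.
jax.config.update("jax_platform_name", "cpu")


def step(t, state):
    # Hypothetical stand-in for wa_update: branch on the step index with lax.cond.
    # Assumes the modern signature lax.cond(pred, true_fun, false_fun, *operands).
    return lax.cond(t < 10, lambda s: s + 1.0, lambda s: s, state)


def run_smoke_test(num_steps=10):
    # Run the jitted step repeatedly; on an affected setup this is roughly
    # the pattern (jit + cond on a CPU-pinned platform) that crashed.
    step_jit = jit(step)
    state = jnp.zeros(3)
    for t in range(num_steps):
        state = step_jit(t, state)
    return state


if __name__ == "__main__":
    print(run_smoke_test())
```

If the process exits normally and prints an array, the jit-plus-cond path ran on the CPU-pinned platform without crashing. That alone does not prove the original bug is absent, since this sketch is much smaller than the repro.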
Uploaded jax 0.1.43 to PyPI with this fix.
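To confirm that an environment has picked up that release, a small version check along these lines may help. This is only a sketch: the helper names `parse_version`, `has_fix`, and `installed_jax_has_fix` are hypothetical, and the only fact taken from this thread is that the fix shipped in jax 0.1.43.

```python
from importlib import metadata

# The release named in this thread as containing the fix.
FIXED_JAX_VERSION = (0, 1, 43)


def parse_version(text):
    """Turn '0.1.43' into (0, 1, 43), stopping at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def has_fix(version_text):
    """Return True if the given jax version is at least 0.1.43."""
    return parse_version(version_text) >= FIXED_JAX_VERSION


def installed_jax_has_fix():
    """Check the installed jax; return None if jax is not installed."""
    try:
        return has_fix(metadata.version("jax"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(installed_jax_has_fix())
```

Note that this checks the `jax` package only. The thread attributes the bug to jax's xla.py rather than to jaxlib, so the jaxlib wheel version is not inspected here.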