
Segfault when running CPU code with a GPU-enabled jaxlib


The following repro script, which is taken from #1239,

```python
import time

import jax
from jax.config import config
config.update('jax_platform_name', 'cpu')  # force the CPU backend
import jax.numpy as np
from jax import random, lax, jit


def welford_covariance():
    def init_fn(size):
        return np.zeros(size), np.zeros(size), 0

    def update_fn(sample, state):
        mean, m2, n = state
        n = n + 1
        delta_pre = sample - mean
        mean = mean + delta_pre / n
        delta_post = sample - mean
        m2 = m2 + delta_pre * delta_post
        return mean, m2, n

    def final_fn(state):
        mean, m2, n = state
        cov = m2 / (n - 1)
        cov_inv_sqrt = np.sqrt(np.reciprocal(cov))
        return cov, cov_inv_sqrt

    return init_fn, update_fn, final_fn


def warmup_adapter():
    mm_init, mm_update, mm_final = welford_covariance()

    def init_fn(z, rng, mass_matrix_size):
        inverse_mass_matrix = np.ones(mass_matrix_size)
        mass_matrix_sqrt = inverse_mass_matrix
        mm_state = mm_init(mass_matrix_size)
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def _update_at_window_end(z, rng_ss, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        inverse_mass_matrix, mass_matrix_sqrt = mm_final(mm_state)
        mm_state = mm_init(inverse_mass_matrix.shape[-1])
        return (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)

    def update_fn(t, accept_prob, z, state):
        inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng = state
        rng, rng_ss = random.split(rng)
        state = (inverse_mass_matrix, mass_matrix_sqrt, mm_state, rng)
        state = lax.cond(t < 10,
                         (z, rng_ss, state), lambda args: _update_at_window_end(*args),
                         state, lambda x: x)
        return state

    return init_fn, update_fn


wa_init, wa_update = warmup_adapter()
wa_update = jit(wa_update)  # per the note in #1239, jit-compiling this makes the loop fast

z = np.ones(3)
wa_state = wa_init(z, random.PRNGKey(0), mass_matrix_size=3)
for t in range(10):
    tic = time.time()
    wa_state = wa_update(t, 0.1 * t, z, wa_state)
    print(time.time() - tic)  # seconds per update step
```

causes the following output when run with jaxlib 0.1.23 built for CUDA 10.0 (https://storage.googleapis.com/jax-releases/cuda100/jaxlib-0.1.23-cp36-none-linux_x86_64.whl):

```
[mutex.cc : 419] RAW: Lock blocking 0x7fbae0006640   @ 0x7fbbc9577eaf 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
[mutex.cc : 419] RAW: Unlock 0x7fbae0006640   @ 0x7fbbc6a80ae2 0x7fbbc9578088 0x7fbbc95785a6 0x7fbbc6a81236 0x7fbbc9576913 0x7fbbc9463f12 0x7fbbf463c421
```
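The repro uses the five-argument form of lax.cond, `lax.cond(pred, true_operand, true_fun, false_operand, false_fun)`, which applies `true_fun` to `true_operand` when `pred` holds and `false_fun` to `false_operand` otherwise. Note that `t < 10` is true for every `t` in `range(10)`, so the window-end branch is taken on every iteration. The plain-Python function below is only a sketch of those call semantics for concrete predicates; `cond_reference` and `window_end` are hypothetical names, and the real lax.cond traces both branches rather than running a Python `if`.

```python
from typing import Any, Callable


def cond_reference(pred: bool,
                   true_operand: Any, true_fun: Callable[[Any], Any],
                   false_operand: Any, false_fun: Callable[[Any], Any]) -> Any:
    """Eager, plain-Python sketch of the five-argument lax.cond call form."""
    # Only the selected branch runs here; lax.cond instead traces both branches
    # into one conditional and picks one when the compiled code executes.
    if pred:
        return true_fun(true_operand)
    return false_fun(false_operand)


# Same shape as the call in update_fn: the true branch receives an argument
# tuple, the false branch is the identity on the state.
state = ("inv_mass", "mass_sqrt", "mm_state", "rng")


def window_end(args):
    # Hypothetical stand-in for _update_at_window_end: tag the state it received.
    z, rng_ss, st = args
    return ("window_end",) + st


taken = cond_reference(3 < 10, ("z", "rng_ss", state), window_end, state, lambda x: x)
skipped = cond_reference(12 < 10, ("z", "rng_ss", state), window_end, state, lambda x: x)
```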

This issue only started after the recent JAX refactoring (related to the removal of tuples). The code runs fine if we change config.update('jax_platform_name', 'cpu') to config.update('jax_platform_name', 'gpu'), or if we use the CPU-only jaxlib from PyPI.

cc @mattjj @skye
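For context on the repro itself, `welford_covariance` follows Welford's online algorithm: each sample updates a running mean and a running sum of squared deviations `m2`, and `final_fn` divides `m2` by `n - 1`. The stdlib-only sketch below mirrors `init_fn`, `update_fn`, and `final_fn` for a single scalar stream so the update rule can be checked against `statistics.variance`; the function names are hypothetical, and it omits the inverse-square-root output and the elementwise array handling of the original.

```python
import statistics


def welford_init():
    # (mean, m2, n): the same state layout as init_fn in the repro
    return 0.0, 0.0, 0


def welford_update(sample, state):
    mean, m2, n = state
    n = n + 1
    delta_pre = sample - mean          # deviation from the old mean
    mean = mean + delta_pre / n
    delta_post = sample - mean         # deviation from the updated mean
    m2 = m2 + delta_pre * delta_post
    return mean, m2, n


def welford_final(state):
    mean, m2, n = state
    return m2 / (n - 1)                # unbiased sample variance


samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
state = welford_init()
for x in samples:
    state = welford_update(x, state)

variance = welford_final(state)
reference = statistics.variance(samples)  # two-pass result for comparison
```

The one-pass update avoids storing the samples, which is why the repro can carry only `(mean, m2, n)` through each warmup step.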

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 11 (11 by maintainers)

Top GitHub Comments

mattjj commented, Aug 25, 2019

I think I spotted it! I had assumed this was a jaxlib issue, but when I rolled back jax a bit the segfault went away. I realized xla.py must have been accidentally mixing memory pointers on one device with those on another, or something like that, and I spotted the bug, now fixed in 434d175 (another of mine, I’m afraid, from the big rewrite!).
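The root cause described above, buffers that live on one device being handed to a computation for another, is the kind of mistake a dispatch-time consistency check can catch. The sketch below is purely illustrative: `Buffer`, `check_same_device`, and the device strings are hypothetical stand-ins, not JAX's xla.py internals, and the actual fix in 434d175 may look quite different.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class Buffer:
    # Hypothetical stand-in for a device buffer: a handle plus the device it lives on.
    handle: int
    device: str


def check_same_device(args: Sequence[Buffer], target: str) -> None:
    """Raise instead of dispatching if any argument lives on another device."""
    wrong = [b for b in args if b.device != target]
    if wrong:
        raise ValueError(
            f"computation compiled for {target!r} got buffers on "
            f"{sorted({b.device for b in wrong})}")


cpu_args = [Buffer(1, "cpu:0"), Buffer(2, "cpu:0")]
mixed_args = [Buffer(1, "cpu:0"), Buffer(3, "gpu:0")]
check_same_device(cpu_args, "cpu:0")  # passes silently
```

Failing loudly with a Python exception is far easier to debug than the silent pointer mix-up, which here only surfaced as a native crash.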

I’m not sure why we weren’t able to repro this internally. I’m working on a CI test to guard against a regression, but for now I’ve verified that this fixes the issue on a GPU cloud VM.
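Because a segfault kills the interpreter outright, a regression test for one generally has to run the repro in a child process and inspect how that process exited. A minimal POSIX-only sketch, assuming the repro is passed as a source string: `died_with_segfault` is a hypothetical helper, and `CRASHING_REPRO` is a placeholder that deliberately raises SIGSEGV, not the CI test mentioned above. On POSIX, `subprocess` reports death by signal N as a return code of -N.

```python
import signal
import subprocess
import sys

# Hypothetical placeholders: a child that deliberately dies with SIGSEGV and one
# that exits normally. A real test would put the repro script here instead.
CRASHING_REPRO = "import os, signal; os.kill(os.getpid(), signal.SIGSEGV)"
CLEAN_REPRO = "print('ok')"


def died_with_segfault(source: str, timeout: float = 60.0) -> bool:
    """Run `source` in a fresh interpreter; True if it was killed by SIGSEGV."""
    proc = subprocess.run([sys.executable, "-c", source],
                          capture_output=True, timeout=timeout)
    # On POSIX, a negative return code -N means the child was killed by signal N.
    return proc.returncode == -signal.SIGSEGV


crashed = died_with_segfault(CRASHING_REPRO)
clean = died_with_segfault(CLEAN_REPRO)
```

Running in a fresh interpreter also means the child picks its own backend at import time, so a CPU-only configuration in the repro cannot leak from the test harness.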

mattjj commented, Aug 25, 2019

Uploaded jax 0.1.43 to PyPI with this fix.
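Since the fix is reported to ship in jax 0.1.43, a script that relies on it could refuse to run on older releases. A stdlib-only sketch under that assumption; `MIN_FIXED`, `parse_release`, and `has_fix` are hypothetical names, and the parser only handles plain `X.Y.Z` release strings (a library such as `packaging` copes better with pre-release or dev suffixes).

```python
from typing import Tuple

MIN_FIXED = (0, 1, 43)  # first jax release said to contain the fix


def parse_release(version: str) -> Tuple[int, ...]:
    # Accepts plain "X.Y.Z" strings only; pre-release/dev suffixes are not handled.
    return tuple(int(part) for part in version.split("."))


def has_fix(version: str) -> bool:
    # Tuple comparison is element-wise, so (0, 2, 0) > (0, 1, 43).
    return parse_release(version) >= MIN_FIXED


# In a real script this would be has_fix(jax.__version__).
print(has_fix("0.1.42"), has_fix("0.1.43"), has_fix("0.2.0"))  # → False True True
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.1.100" would sort before "0.1.43".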
