JAX/TF2 on multi-GPU silently corrupted computations

See original GitHub issue

With 1 or 2 GPUs there is no problem; with 3 or more GPUs the issue manifests as silently corrupted computations. This was tested on 4 V100 GPUs. There is a workaround (see the commented-out line in the code below).

Code for reproduction:

# The bug only shows for physical GPUs, and only when there are 3 or more GPUs.
# It silently corrupts the results with no warning or error message.

import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf

# The following is a workaround that prevents the issue:
# tf.config.experimental.set_visible_devices([], "GPU")

reproduce_bug = True
noop = jax.pmap(lambda x: x)
n = jax.device_count()


def check(delta):
    v = noop(jn.ones((n, 1)) * delta).sum()
    assert v == n * delta, (v, n * delta)


check(1)
if reproduce_bug:
    data = tf.data.Dataset.from_tensor_slices(dict(x=np.random.uniform(-1, 1, size=(16, 1))))

check(2)
# Output is
# Traceback (most recent call last):
#  File "<stdin>", line 1, in <module>
#  File "<stdin>", line 3, in check
# AssertionError: (DeviceArray(7., dtype=float32), 8)
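
For completeness, here is a minimal sketch of applying the workaround from the commented-out line above: hiding every GPU from TensorFlow before it initializes CUDA, so that only JAX ever touches the GPU contexts. The surrounding code mirrors the reproduction; on an affected machine the final assertion should now pass.

import tensorflow as tf

# Apply the workaround before TensorFlow touches CUDA: give TF no visible GPUs,
# so only JAX changes the current GPU context from here on.
tf.config.experimental.set_visible_devices([], "GPU")

import jax
import jax.numpy as jn
import numpy as np

noop = jax.pmap(lambda x: x)
n = jax.device_count()

# Building the tf.data pipeline no longer corrupts subsequent JAX computations.
data = tf.data.Dataset.from_tensor_slices(dict(x=np.random.uniform(-1, 1, size=(16, 1))))
v = noop(jn.ones((n, 1)) * 2).sum()
assert v == n * 2, (v, n * 2)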

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 13 (10 by maintainers)

Top GitHub Comments

5 reactions
hawkinsp commented, Jan 14, 2021

NVidia closed the bug because they were unable to commit engineering resources to look into it.

Closing because there’s no action we can take here. We can’t fix NVidia driver/runtime bugs.

3 reactions
hawkinsp commented, Aug 7, 2020

Aha! I think I’ve figured this one out.

JAX and TF both use the stream_executor library inside the TensorFlow tree to interact with CUDA. stream_executor has an optimization where it caches the last GPU context it set in thread-local storage, and skips the call to cuCtxSetCurrent if it thinks the current CUDA context has not changed since last time it was set: https://github.com/tensorflow/tensorflow/blob/001ec7efbed18e9581e859513c5acc76e5aabbe9/tensorflow/stream_executor/cuda/cuda_driver.cc#L204

Inside Google we like linking our binaries statically (since you get hermetic artifacts, at the cost of requiring much more rebuilding), and if we are using JAX and TF together they both share a single copy of stream_executor. So they both use the same thread-local cache.

However, in open source, we distribute JAX and TF as separate Python plugins, each with its own copy of stream_executor, built with private symbol visibility. This means that rather than having one cache, we have two. They interact badly if you mix TF and JAX in the same binary: TF will change the current GPU, and JAX will fail to notice this, and vice versa. The net effect is that we end up running things on the wrong GPU!
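
To make the failure mode concrete, here is a small, purely hypothetical Python model of the situation — not the real stream_executor code, just the caching idea with two private copies sharing one driver:

import threading

driver = {"current": None}   # stand-in for the CUDA driver's real current context

class CachedContextSetter:
    # Stand-in for one statically linked copy of stream_executor's
    # thread-local "last context set" cache.
    def __init__(self):
        self.tls = threading.local()

    def activate(self, ctx):
        # The optimization: skip the cuCtxSetCurrent-style driver call if this
        # copy believes ctx is already the current context on this thread.
        if getattr(self.tls, "current", None) == ctx:
            return
        driver["current"] = ctx
        self.tls.current = ctx

jax_copy = CachedContextSetter()    # JAX's private copy of the cache
tf_copy = CachedContextSetter()     # TF's private copy of the cache

jax_copy.activate("gpu2")           # JAX makes GPU 2 current
tf_copy.activate("gpu0")            # TF switches the real context to GPU 0
jax_copy.activate("gpu2")           # JAX's cache says "already current", so the call is skipped
assert driver["current"] == "gpu0"  # JAX's next launch would land on the wrong GPU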

I suspect the right fix is to flush the cache at the Python API boundary, which is certainly something I can do on the JAX side.

TF should probably do the same, given that this isn’t a problem specific to JAX: it could be any GPU-using Python library that does this.
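
In terms of that toy model, flushing the cache at the Python API boundary means dropping the cached value whenever control re-enters the library from Python, so the next activation always re-issues the driver call. A sketch of the idea (hypothetical names again, not the actual JAX or TF change):

import threading

driver = {"current": None}   # stand-in for the CUDA driver's real current context

class CachedContextSetter:
    # Same toy cache as in the sketch above, now with a flush() hook.
    def __init__(self):
        self.tls = threading.local()

    def activate(self, ctx):
        if getattr(self.tls, "current", None) == ctx:
            return                  # cached: skip the driver call
        driver["current"] = ctx     # stands in for cuCtxSetCurrent
        self.tls.current = ctx

    def flush(self):
        # Forget the cached context so the next activate() unconditionally
        # calls back into the driver, even if another copy changed it.
        self.tls.current = None

jax_copy = CachedContextSetter()
tf_copy = CachedContextSetter()

jax_copy.activate("gpu2")
tf_copy.activate("gpu0")            # TF changes the context behind JAX's back
jax_copy.flush()                    # done at JAX's Python API boundary
jax_copy.activate("gpu2")           # the stale cache can no longer mask the change
assert driver["current"] == "gpu2"  # work runs on the intended GPU again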
