JAX/TF2 on multi-GPU silently corrupted computations
With 1 or 2 GPUs there is no problem. With 3 or more GPUs, the issue manifests and computations are silently corrupted. This was tested on 4 V100 GPUs. There is a workaround (see the commented-out line in the code below).
Code for reproduction:
# The bug only shows for physical GPUs, and only when there 3 or more GPUs.
# It silently corrupts the results with no warning or error message.
import jax
import jax.numpy as jn
import numpy as np
import tensorflow as tf
# The following is a workaround that prevents the issue:
# tf.config.experimental.set_visible_devices([], "GPU")
reproduce_bug = True
noop = jax.pmap(lambda x: x)
n = jax.device_count()
def check(delta):
    v = noop(jn.ones((n, 1)) * delta).sum()
    assert v == n * delta, (v, n * delta)
check(1)
if reproduce_bug:
    data = tf.data.Dataset.from_tensor_slices(dict(x=np.random.uniform(-1, 1, size=(16, 1))))
check(2)
# Output is
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "<stdin>", line 3, in check
# AssertionError: (DeviceArray(7., dtype=float32), 8)
Issue Analytics
- State: closed
- Created 3 years ago
- Comments: 13 (10 by maintainers)
NVidia closed the bug because they were unable to commit engineering resources to look into it.
Closing because there’s no action we can take here. We can’t fix NVidia driver/runtime bugs.
Aha! I think I’ve figured this one out.

JAX and TF both use the stream_executor library inside the TensorFlow tree to interact with CUDA. stream_executor has an optimization where it caches the last GPU context it set in thread-local storage, and skips the call to cuCtxSetCurrent if it thinks the current CUDA context has not changed since it was last set: https://github.com/tensorflow/tensorflow/blob/001ec7efbed18e9581e859513c5acc76e5aabbe9/tensorflow/stream_executor/cuda/cuda_driver.cc#L204

Inside Google we like linking our binaries statically (you get hermetic artifacts, at the cost of much more rebuilding), so when we use JAX and TF together they share a single copy of stream_executor, and therefore a single thread-local cache. In open source, however, we distribute JAX and TF as separate Python plugins, each with its own copy of stream_executor, built with private symbol visibility. This means that rather than having one cache, we have two. They interact badly if you mix TF and JAX in the same binary: TF changes the current GPU, JAX fails to notice, and vice versa. The net effect is that we end up running things on the wrong GPU!

I suspect the right fix is to flush the cache at the Python API boundary, which is certainly something I can do on the JAX side. TF should probably do the same, given that this isn’t a problem specific to JAX: it could be any GPU-using Python library that does this.
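The failure mode above can be modeled in a few lines of plain Python. This is only an illustrative sketch, not the real stream_executor code: `FakeDriver` stands in for the CUDA driver's single global current context, and each `CachingWrapper` instance stands in for one statically linked copy of stream_executor with its own (in reality thread-local) cache. The `flush` method models the proposed fix of invalidating the cache at the Python API boundary.

```python
class FakeDriver:
    """Stands in for the CUDA driver: one global current context."""
    current_context = None


class CachingWrapper:
    """Stands in for one private copy of stream_executor's context cache."""

    def __init__(self):
        self._cached = None  # thread-local in the real code

    def set_context(self, ctx):
        # The optimization: skip the driver call if this copy believes
        # the context is unchanged since it last set it.
        if self._cached == ctx:
            return
        FakeDriver.current_context = ctx
        self._cached = ctx

    def flush(self):
        # Proposed fix: invalidate the cache at the API boundary so the
        # next set_context call always reaches the driver.
        self._cached = None


jax_copy = CachingWrapper()  # JAX's private copy of the cache
tf_copy = CachingWrapper()   # TF's private copy of the cache

jax_copy.set_context("gpu0")       # driver is now on gpu0
tf_copy.set_context("gpu2")        # TF switches the driver to gpu2
jax_copy.set_context("gpu0")       # JAX's cache says "gpu0" -> call skipped!
stale = FakeDriver.current_context  # "gpu2": JAX now runs on the wrong GPU

jax_copy.flush()                   # with the fix: drop the stale cache entry
jax_copy.set_context("gpu0")       # this call reaches the driver again
fixed = FakeDriver.current_context  # "gpu0"
print(stale, fixed)
```

With a single shared copy of the cache (the statically linked case), the second `set_context("gpu0")` would correctly see that TF had changed the context and reissue the driver call, which is why the bug never appeared inside Google.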