Wrong values on second GPU when distributing data SPMD-style [Tesla K80]
Description
Hi, I'm a newcomer to jax (previously I used tensorflow). While learning jax, I found a strange behaviour when testing single-program, multiple-data (SPMD) programs on my local Tesla K80 (I believe this is a single board with two GPU chips). When I ran a pmap function, the data stored on (sent to?) the second GPU was different from the data on the first one. To demonstrate the problem I built a simple script around the jax.device_put and jax.device_put_replicated methods, which show the same wrong behaviour as pmap; I suppose they rely on the same low-level mechanism for data replication.
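For reference, the pmap symptom can be sketched in a few lines like the ones below (a hypothetical reproduction, not the script from this report; it just maps an identity computation over one slice per device, and on the affected machine the slice mapped to gpu:1 would presumably come back corrupted):

import jax
import jax.numpy as jnp

n = jax.local_device_count()        # 2 on this K80 board
x = jnp.ones((n, 2, 2))             # leading axis: one 2x2 slice per device
y = jax.pmap(lambda s: s * 1.0)(x)  # identity-like per-device computation
print(y)                            # expected: all ones on both devices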
To narrow the problem down, I ran three simple experiments:
- In the first experiment, I used jax.device_put_replicated to replicate a two-by-two matrix of ones onto both GPUs. However, on the second GPU I got a two-by-two matrix of zeros.
Jax visible devices: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
--- Local values ---
gpu:0
[[1. 1.]
[1. 1.]]
--- Distributed values ---
Device: gpu:0
Values:
[[1. 1.]
[1. 1.]]
Device: gpu:1
Values:
[[0. 0.]
[0. 0.]]
- In the second experiment, I replaced jax.device_put_replicated with jax.device_put (selecting each GPU individually), and, as expected, the same wrong behaviour was observed.
--- Second experiment ---
--- Local values ---
gpu:0
[[1. 1.]
[1. 1.]]
--- Distributed values ---
Device: gpu:0
Values:
[[1. 1.]
[1. 1.]]
Device: gpu:1
Values:
[[0. 0.]
[0. 0.]]
- Finally, I created the values directly on GPU 1 (with the help of jit), since by default they are allocated on GPU 0, and then replicated them using jax.device_put and jax.device_put_replicated.
--- Third experiment ---
--- Locally create on GPU: 1 ---
gpu:1
[[1. 1.]
[1. 1.]]
--- Distributed values with device_put ---
Device: gpu:0
Values:
[[1. 0.]
[0. 0.]]
Device: gpu:1
Values:
[[1. 1.]
[1. 1.]]
--- Distributed values with device_put_replicated ---
Device: gpu:0
Values:
[[1. 1.]
[1. 1.]]
Device: gpu:1
Values:
[[0. 0.]
[0. 0.]]
As can be seen, creating the values directly on GPU 1 works. However, when I copy them from that GPU (GPU 1) to the other one (GPU 0), the transferred values are still wrong.
Other info: I ran the program with CUDA_VISIBLE_DEVICES=0,1. If I run the same code on a single GPU at a time (e.g. CUDA_VISIBLE_DEVICES=1), I do not see any strange behaviour. I also ran the code on a desktop with two RTX 2070 cards, using the same base container, and had no issues there.
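One way to pin the process to a single K80 chip from inside Python is sketched below (an illustration of the approach only; the actual runs set CUDA_VISIBLE_DEVICES in the shell environment before launching the script):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # must be set before JAX initializes its backends

import jax
print(jax.devices())  # expected: a single GpuDevice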
The script I ran to check the behaviour described above:
## DEBUG of K80 multi GPU setup
import jax
import jax.numpy as jnp


def first_experiment():
    # Replicate a 2x2 matrix of ones onto every visible device.
    local = jnp.ones((2, 2))
    print("--- Local values ---")
    print(local.device(), "\n", local, "\n\n")
    print("--- Distributed values ---")
    out = jax.device_put_replicated(local, jax.devices()).block_until_ready()
    print("\n\n".join([f"Device: {buf.device()} \nValues:\n {buf}" for buf in out.device_buffers]))


def second_experiment():
    # Same data, but placed on each GPU individually with device_put.
    local = jnp.ones((2, 2))
    print("--- Local values ---")
    print(local.device(), "\n", local, "\n\n")
    print("--- Distributed values ---")
    out_0 = jax.device_put(local, jax.devices()[0]).block_until_ready()
    # manually select gpu 1
    out_1 = jax.device_put(local, jax.devices()[1]).block_until_ready()
    print("\n\n".join([f"Device: {out.device()} \nValues:\n {out}" for out in [out_0, out_1]]))


def third_experiment():
    # Create the values directly on GPU 1 via jit, then copy/replicate them to GPU 0.
    print("--- Locally create on GPU: 1 ---")
    local = jax.jit(lambda x: jnp.ones((2, 2)), device=jax.devices()[1])(0).block_until_ready()
    print(local.device(), "\n", local, "\n\n")
    # put on gpu 0
    print("--- Distributed values with device_put ---")
    out = jax.device_put(local, jax.devices()[0]).block_until_ready()
    print("\n\n".join([f"Device: {buf.device()} \nValues:\n {buf}" for buf in [out, local]]))
    print("\n")
    print("--- Distributed values with device_put_replicated ---")
    out = jax.device_put_replicated(local, jax.devices()).block_until_ready()
    print("\n\n".join([f"Device: {buf.device()} \nValues:\n {buf}" for buf in out.device_buffers]))


if __name__ == "__main__":
    print("\nJax visible devices: ", jax.devices(), "\n")
    first_experiment()
    print("\n--- Second experiment ---\n")
    second_experiment()
    print("\n--- Third experiment ---\n")
    third_experiment()
What jax/jaxlib version are you using?
0.3.15+cuda11.cudnn82
Which accelerator(s) are you using?
GPU - Tesla K80 (single board with two GPU chips)
Additional system info
Everything runs inside a container based on the image nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04 | Python 3.9.14 | Ubuntu 20.04.4 LTS
NVIDIA GPU info
nvidia-smi
Tue Sep 27 11:00:52 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.00 Driver Version: 470.82.00 CUDA Version: 11.4 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla K80 Off | 00000000:86:00.0 Off | 0 |
| N/A 51C P0 60W / 149W | 0MiB / 11441MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 Tesla K80 Off | 00000000:87:00.0 Off | 0 |
| N/A 41C P0 72W / 149W | 0MiB / 11441MiB | 96% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
nvidia-smi topo -m
GPU0 GPU1 CPU Affinity NUMA Affinity
GPU0 X PIX 10-19,30-39 1
GPU1 PIX X 10-19,30-39 1
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
Comments: 7 (2 by maintainers)
I think if the simpleP2P program fails, that means this is almost certainly not a JAX bug. It seems like either a CUDA bug or a hardware failure. Peer-to-peer transfers are exactly what JAX is doing under the hood also!
This is a great tip: I’ll make sure to ask anyone reporting a similar issue in the future to try simpleP2P as well.
Closing since there’s no action we can take. I hope that helps!
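A minimal JAX-level round-trip check that exercises the same GPU-to-GPU copy path could look like the sketch below (not part of the original thread; it assumes two visible GPUs and uses only jax.device_put plus a host-side comparison). On a healthy setup it prints True; on the affected K80 board the GPU 0 to GPU 1 leg would presumably corrupt the data:

import numpy as np
import jax
import jax.numpy as jnp

gpu0, gpu1 = jax.devices()[:2]
ref = np.arange(16, dtype=np.float32).reshape(4, 4)

on_gpu0 = jax.device_put(jnp.asarray(ref), gpu0)
on_gpu1 = jax.device_put(on_gpu0, gpu1)           # GPU 0 -> GPU 1 copy
back = np.asarray(jax.device_put(on_gpu1, gpu0))  # GPU 1 -> GPU 0, then to host

print("round trip intact:", np.array_equal(back, ref))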
It definitely feels like a CUDA/driver issue, since I could not reproduce it on a different host with an A100 GPU and a newer CUDA version. Though I still don't understand what the issue itself is.