
Wrong values on second GPU when distributing the data SPMD [Tesla K80]

See original GitHub issue

Description

Hi, I’m a newcomer to jax (previously I used TensorFlow). While learning jax, I found strange behaviour when testing single-process, multiple-data (SPMD) programs on my local Tesla K80 (I believe this is a single board with two GPU chips). When I ran a pmap function, the data stored on (sent to?) the second GPU was different from the data on the first one. To demonstrate the problem I built a simple script around the jax.device_put and jax.device_put_replicated functions, which show the same wrong behaviour I observed with pmap (I suppose they use the same low-level mechanism for data replication).
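
For context, a minimal pmap check of the kind where I first noticed the problem might look like the sketch below (this is not my original code, just an illustrative example):

import jax
import jax.numpy as jnp

# Give every local device a row of ones and sum across devices with psum.
# On a healthy two-GPU setup, every entry of the result should equal 2.
n = jax.local_device_count()
x = jnp.ones((n, 4))
out = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(x)
print(out)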

To better understand the issue, I ran three simple experiments:

  • In the first experiment, I used jax.device_put_replicated to replicate a two-by-two matrix of ones onto both GPUs. However, on the second GPU I got a two-by-two matrix of zeros.
Jax visible devices:  [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)] 

--- Local values ---
gpu:0 
 [[1. 1.]
 [1. 1.]] 


--- Distributed values ---
Device: gpu:0 
Values:
 [[1. 1.]
 [1. 1.]]

Device: gpu:1 
Values:
 [[0. 0.]
 [0. 0.]]

  • In the second experiment, I replaced jax.device_put_replicated with jax.device_put (selecting each GPU individually), and, as expected, the same wrong behaviour was observed.
--- Second experiment ---

--- Local values ---
gpu:0 
 [[1. 1.]
 [1. 1.]] 


--- Distributed values ---
Device: gpu:0 
Values:
 [[1. 1.]
 [1. 1.]]

Device: gpu:1 
Values:
 [[0. 0.]
 [0. 0.]]

  • Finally, I executed the code directly on GPU 1 (with the help of jit), since by default values are allocated on GPU 0, and then performed the replication using jax.device_put and jax.device_put_replicated.
--- Third experiment ---

--- Locally create on GPU: 1 ---
gpu:1 
 [[1. 1.]
 [1. 1.]] 


--- Distributed values with device_put ---
Device: gpu:0 
Values:
 [[1. 0.]
 [0. 0.]]

Device: gpu:1 
Values:
 [[1. 1.]
 [1. 1.]]


--- Distributed values with device_put_replicated ---
Device: gpu:0 
Values:
 [[1. 1.]
 [1. 1.]]

Device: gpu:1 
Values:
 [[0. 0.]
 [0. 0.]]

As can be seen, running the instruction directly on GPU 1 works. However, when I copy the data from that GPU (GPU 1) to the other one (GPU 0), the replicated values are still wrong.

Other info: I ran the program with CUDA_VISIBLE_DEVICES=0,1. If I run the same code on a single GPU (e.g. CUDA_VISIBLE_DEVICES=1), I do not see any strange behaviour. I also ran the code on a desktop with two RTX 2070s and had no issues there; I used the same base container for all of these experiments.

The script I ran to check the previously described jax behaviour:


## DEBUG of K80 multi-GPU setup
import jax
import jax.numpy as jnp

def first_experiment():
    # Replicate a host-created array of ones onto all visible devices.
    local = jnp.ones((2, 2))
    print("--- Local values ---")
    print(local.device(), "\n", local, "\n\n")

    print("--- Distributed values ---")
    out = jax.device_put_replicated(local, jax.devices()).block_until_ready()
    print("\n\n".join([f"Device: {buf.device()} \nValues:\n {buf}" for buf in out.device_buffers]))

def second_experiment():
    # Same check, but place the array on each GPU explicitly with device_put.
    local = jnp.ones((2, 2))
    print("--- Local values ---")
    print(local.device(), "\n", local, "\n\n")

    print("--- Distributed values ---")
    out_0 = jax.device_put(local, jax.devices()[0]).block_until_ready()
    # manually select gpu 1
    out_1 = jax.device_put(local, jax.devices()[1]).block_until_ready()
    print("\n\n".join([f"Device: {out.device()} \nValues:\n {out}" for out in [out_0, out_1]]))

def third_experiment():
    # Create the array directly on GPU 1 via jit, then copy/replicate it from there.
    print("--- Locally create on GPU: 1 ---")
    local = jax.jit(lambda x: jnp.ones((2, 2)), device=jax.devices()[1])(0).block_until_ready()
    print(local.device(), "\n", local, "\n\n")

    # put on gpu 0
    print("--- Distributed values with device_put ---")
    out = jax.device_put(local, jax.devices()[0]).block_until_ready()
    print("\n\n".join([f"Device: {buf.device()} \nValues:\n {buf}" for buf in [out, local]]))
    print("\n")
    print("--- Distributed values with device_put_replicated ---")
    out = jax.device_put_replicated(local, jax.devices()).block_until_ready()
    print("\n\n".join([f"Device: {buf.device()} \nValues:\n {buf}" for buf in out.device_buffers]))


if __name__ == "__main__":

    print("\nJax visible devices: ", jax.devices(), "\n")

    first_experiment()
    print("\n--- Second experiment ---\n")
    second_experiment()

    print("\n--- Third experiment ---\n")
    third_experiment()
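
To avoid comparing the printed matrices by eye, a small host-side check could be appended to the script. The sketch below (the helper name check_replication is my own, not part of the original script) pulls every per-device copy back to the host and compares it numerically:

import jax
import jax.numpy as jnp
import numpy as np

def check_replication():
    # Replicate a host value onto every visible device, pull all copies
    # back to the host, and compare each one against the original.
    local = jnp.ones((2, 2))
    expected = np.asarray(local)
    replicated = jax.device_put_replicated(local, jax.devices()).block_until_ready()
    gathered = jax.device_get(replicated)  # numpy array, shape (num_devices, 2, 2)
    for device, copy in zip(jax.devices(), gathered):
        status = "OK" if np.array_equal(copy, expected) else "MISMATCH"
        print(f"{device}: {status}")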

What jax/jaxlib version are you using?

0.3.15+cuda11.cudnn82

Which accelerator(s) are you using?

GPU - Tesla K80 (a single board with two GPU chips)

Additional system info

Everything runs inside a container based on the nvidia/cuda:11.4.3-cudnn8-devel-ubuntu20.04 image | Python 3.9.14 | Ubuntu 20.04.4 LTS

NVIDIA GPU info

nvidia-smi

Tue Sep 27 11:00:52 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.00    Driver Version: 470.82.00    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:86:00.0 Off |                    0 |
| N/A   51C    P0    60W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla K80           Off  | 00000000:87:00.0 Off |                    0 |
| N/A   41C    P0    72W / 149W |      0MiB / 11441MiB |     96%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

nvidia-smi topo -m

        GPU0    GPU1    CPU Affinity    NUMA Affinity
GPU0     X      PIX     10-19,30-39     1
GPU1    PIX      X      10-19,30-39     1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 7 (2 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Sep 27, 2022

I think if the simpleP2P program fails, that means this is almost certainly not a JAX bug. It seems like either a CUDA bug or a hardware failure. Peer-to-peer transfers are exactly what JAX is doing under the hood as well!

This is a great tip: I’ll make sure to ask anyone reporting a similar issue in the future to try simpleP2P as well.

Closing since there’s no action we can take. I hope that helps!
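
For reference, a rough JAX-level analogue of the bidirectional copy that simpleP2P exercises might look like the sketch below (this is only an illustration, not the CUDA sample itself; the helper name p2p_round_trip is made up):

import jax
import jax.numpy as jnp
import numpy as np

def p2p_round_trip(src, dst):
    # Create data on one GPU, copy it to the other GPU, bring both copies
    # back to the host, and check that they match.
    x = jax.device_put(jnp.arange(16.0).reshape(4, 4), src).block_until_ready()
    y = jax.device_put(x, dst).block_until_ready()
    return np.array_equal(jax.device_get(x), jax.device_get(y))

gpus = jax.devices()
print("gpu:0 -> gpu:1:", p2p_round_trip(gpus[0], gpus[1]))
print("gpu:1 -> gpu:0:", p2p_round_trip(gpus[1], gpus[0]))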

0 reactions
kefirski commented, Nov 30, 2022

It definitely feels like a CUDA (or related) issue, since I could not reproduce it on a different host with an A100 GPU and a newer CUDA version. Though I still don’t understand what the issue itself is.
