Segmentation fault on GPU-to-GPU communication
I get a segmentation fault with some MPI primitives when using CUDA-enabled MPI. The issue seems to appear when XLA is not initialized, as the error disappears if memory is allocated on the GPU before mpi4jax is imported.
Run command used (with GPUs on two separate nodes):
MPI4JAX_USE_CUDA_MPI=1 mpiexec -npernode 1 python run.py
Contents of run.py to reproduce the error:
import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
rank = comm.Get_rank()

# Broadcast rank 0's value to all processes; this is where the segfault occurs
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)
Error message:
--------------------------------------------------------------------------
Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec noticed that process rank 0 with PID 0 on node 0 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------
Workarounds
1: Using MPI4JAX_USE_CUDA_MPI=0.
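For reference, this is the same launch command as above with CUDA-aware transfers turned off:
MPI4JAX_USE_CUDA_MPI=0 mpiexec -npernode 1 python run.py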
2: Importing jax and creating a DeviceArray fixes the problem, but it has to be added before import mpi4jax. As an example, inserting the following at the beginning of run.py works (a full sketch follows the snippet):
import jax.numpy as jnp
jnp.array(3.)  # allocating on the GPU forces XLA initialization
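Putting the pieces together, a minimal sketch of the patched run.py (the only essential point, as far as I can tell, is that the jax allocation happens before import mpi4jax):

import jax.numpy as jnp
jnp.array(3.)  # initialize XLA on the GPU first

import mpi4jax  # safe to import now
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# bcast no longer segfaults with the early allocation in place
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)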
3: Some primitives (I only tested mpi4jax.allreduce) work just fine out of the box. The following piece of code doesn't crash until the bcast:
rank_sum, _ = mpi4jax.allreduce(rank, op=MPI.SUM, comm=comm)
print(rank, rank_sum)  # allreduce completes fine
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)  # never reached: bcast segfaults
Versions
- Python 3.8.6
- OpenMPI 4.0.5-gcccuda-2020b
- CUDA 11.1.1-GCC-10.2.0
- mpi4py 3.1.1
- mpi4jax 0.3.2
- jax 0.2.21
- jaxlib 0.1.71 [cuda111]
Thanks for the help, though, and for the nice library!
I've never run it through SLURM, sorry. I hope you can figure it out.