
Segmentation fault on GPU to GPU communication


I get a segmentation fault with some MPI primitives when using CUDA-enabled MPI. The issue seems to appear when XLA has not been initialized yet: the error disappears if memory is allocated on the GPU before mpi4jax is imported.

Run command used (with GPUs on two separate nodes):

MPI4JAX_USE_CUDA_MPI=1 mpiexec -npernode 1 python run.py

Contents of run.py to reproduce the error:

import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
rank = comm.Get_rank()

# broadcast the root's rank to every process; the second return value is the mpi4jax token
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)

Error message:

--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec noticed that process rank 0 with PID 0 on node 0 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------

Workarounds

1: Using MPI4JAX_USE_CUDA_MPI=0.
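
For reference, this is just the original launch command with the variable flipped (same two-node setup, nothing else changes):

MPI4JAX_USE_CUDA_MPI=0 mpiexec -npernode 1 python run.py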

2: Importing jax and creating a DeviceArray fixes the problem, but it has to happen before import mpi4jax. For example, inserting this at the beginning of run.py works:

import jax.numpy as jnp
jnp.array(3.)  # allocating on the GPU forces XLA to initialize before mpi4jax is imported

3: Some primitives (I have only tested mpi4jax.allreduce) work just fine out of the box. The following piece of code does not crash before the bcast:

rank_sum, _ = mpi4jax.allreduce(rank, op=MPI.SUM, comm=comm)
print(rank, rank_sum)

root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)
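
For completeness, here is the reproducer with workaround 2 applied; it is just the snippets above stitched together (launched with the same mpiexec command as before), not a proper fix:

import jax.numpy as jnp
jnp.array(3.)  # allocate on the GPU so XLA is initialized before mpi4jax is imported

import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)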

Versions

  • Python 3.8.6
  • OpenMPI 4.0.5-gcccuda-2020b
  • CUDA 11.1.1-GCC-10.2.0
  • mpi4py 3.1.1
  • mpi4jax 0.3.2
  • jax 0.2.21
  • jaxlib 0.1.71[cuda111]
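
For anyone trying to reproduce this, a quick way to collect the Python-side versions and the MPI library string is the snippet below (a small helper, not part of the original report; it assumes the usual __version__ attributes):

import mpi4py, mpi4jax, jax, jaxlib
from mpi4py import MPI

print("mpi4py", mpi4py.__version__)
print("mpi4jax", mpi4jax.__version__)
print("jax", jax.__version__)
print("jaxlib", jaxlib.__version__)
print(MPI.Get_library_version())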

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 27

Top GitHub Comments

1 reaction
halvarsu commented on Oct 20, 2021

Thanks for the help, though, and for the nice library!

0 reactions
dionhaefner commented on Jul 11, 2022

I’ve never run it through SLURM, sorry. Hope you could figure it out.

