multi-machine allreduce

Hi! I am looking to do fast multi-machine allreduce and broadcast operations when using JAX and MPI.

Here is a script that approximates my workload. I ran it on 8 GCE instances, each with 8 V100 GPUs, connected by a 32 Gbit network:

https://gist.github.com/christopherhesse/b5d141a59d9648caab191d9ff6333117

I ran it using mpich:

mpiexec -f <hosts file> python <path to script>

The output looks like this:

num_params 1
compute             : min_elapsed 0.000424  avg_elapsed 0.026451  max_elapsed 0.259608
device_to_host      : min_elapsed 0.000070  avg_elapsed 0.000106  max_elapsed 0.000298
allreduce           : min_elapsed 0.000209  avg_elapsed 0.002230  max_elapsed 0.018252
num_params 16000000
compute             : min_elapsed 0.006838  avg_elapsed 0.023782  max_elapsed 0.155499
device_to_host      : min_elapsed 0.123953  avg_elapsed 0.135843  max_elapsed 0.163817
allreduce           : min_elapsed 0.505218  avg_elapsed 0.592024  max_elapsed 0.640469

So each allreduce of 16M float32s (64 MB) takes about 600 ms on average.

If I use nccl-tests with MPI support (make MPI=1):

mpiexec -f <hosts file> ./nccl-tests/build/all_reduce_perf -b 1M -e 64M -f 2 -g 1 -c 0

The output looks like this:

[0] #                                                     out-of-place                       in-place          
[0] #       size         count    type   redop     time   algbw   busbw  error     time   algbw   busbw  error
[0] #        (B)    (elements)                     (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
[0]      1048576        262144   float     sum[0]    5328.1    0.20    0.39    N/A[0]    3856.1    0.27    0.54    N/A
[0]      2097152        524288   float     sum[0]    6751.2    0.31    0.61    N/A[0]    6132.7    0.34    0.67    N/A
[0]      4194304       1048576   float     sum[0]     11100    0.38    0.74    N/A[0]     10899    0.38    0.76    N/A
[0]      8388608       2097152   float     sum[0]    9818.9    0.85    1.68    N/A[0]    9351.1    0.90    1.77    N/A
[0]     16777216       4194304   float     sum[0]     17219    0.97    1.92    N/A[0]     17121    0.98    1.93    N/A
[0]     33554432       8388608   float     sum[0]     35836    0.94    1.84    N/A[0]     36609    0.92    1.80    N/A
[0]     67108864      16777216   float     sum[0]     73365    0.91    1.80    N/A[0]     78911    0.85    1.67    N/A

Which comes to roughly 80 ms for 16M float32s (the last row: 73 to 79 ms for 67108864 bytes).
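
To put the two measurements on the same scale, the timings can be converted into the "algbw" figure that nccl-tests reports (bytes reduced divided by elapsed time). This is only a back-of-the-envelope sketch: the helper name `algbw_gb_per_s` is made up for illustration, and the inputs are the average MPI allreduce time and the out-of-place NCCL time for the largest message quoted above.

```python
def algbw_gb_per_s(num_elements, seconds, bytes_per_element=4):
    """Algorithm bandwidth in GB/s: payload size divided by elapsed time."""
    return num_elements * bytes_per_element / seconds / 1e9

# MPI path from the script output: 16,000,000 float32s, avg_elapsed 0.592024 s
mpi_bw = algbw_gb_per_s(16_000_000, 0.592024)

# nccl-tests, last row: 16,777,216 float32s, 73365 us out-of-place
nccl_bw = algbw_gb_per_s(16_777_216, 73365e-6)

speedup = nccl_bw / mpi_bw

print(f"MPI  allreduce: {mpi_bw:.2f} GB/s")
print(f"NCCL allreduce: {nccl_bw:.2f} GB/s")
print(f"NCCL is about {speedup:.1f}x faster for this message size")
```

The NCCL figure reproduces the 0.91 GB/s shown in the algbw column, while the MPI path lands near 0.11 GB/s, a gap of roughly 8x.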

For my particular training setup, I am seeing ~600 ms spent doing the allreduce, out of ~800 ms total per training loop, so improving this could improve the runtime of my script substantially.

The two ways that seem most promising to me would be:

  1. Use XLA’s existing NCCL support or extend it to do this call through XLA

  2. Use pointers to GPU memory to call NCCL from Python (not sure if this would encounter weird issues with XLA also using CUDA)

What do you guys think?

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 1
  • Comments: 15 (6 by maintainers)

Top GitHub Comments

hawkinsp commented, Jul 22, 2019 (2 reactions)

The unsafe_buffer_pointer() method should be available in Jaxlib 0.1.22. Please experiment with it, although I wouldn’t consider it a final API.

hawkinsp commented, Aug 12, 2022 (1 reaction)

I think we can consider this one fixed! JAX supports multiworker GPU computation, and you can perform an all-reduce using pmap and psum.
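
As a rough illustration of the pmap and psum approach mentioned in that comment, here is a minimal sketch that sums one array per local device. The function name `allreduce_sum`, the axis name `"devices"`, and the toy input are assumptions made for this example; they are not taken from the issue. On a multi-machine setup each process would also need to initialize JAX's distributed runtime first (for example with `jax.distributed.initialize()`), which is not shown here.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Number of devices visible to this process (1 on a CPU-only machine).
n_dev = jax.local_device_count()

@partial(jax.pmap, axis_name="devices")
def allreduce_sum(x):
    # psum adds x across every device participating in the mapped axis,
    # and every device receives the same summed result.
    return jax.lax.psum(x, axis_name="devices")

# One row of 4 float32 values per device; the leading axis is the device axis.
per_device = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

summed = allreduce_sum(per_device)
print(summed.shape)  # one copy of the reduced row per device
```

With this approach the reduction happens inside the compiled computation, so the gradients do not have to be copied to the host before the allreduce.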
