multi-machine allreduce
Hi! I am looking to do fast multi-machine allreduce and broadcast operations when using JAX and MPI.
Here is a script that should be similar to my workload which I ran on 8 GCE instances with 8 V100 GPUs each and a 32 Gbit network:
https://gist.github.com/christopherhesse/b5d141a59d9648caab191d9ff6333117
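For reference, here is a minimal sketch of the kind of loop the gist times, assuming mpi4py for the CPU-side allreduce (the real script is in the linked gist; the computation here is just a stand-in):

import time
import jax
import jax.numpy as jnp
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

@jax.jit
def compute(x):
    # stand-in for the real per-step computation
    return x * 2.0 + 1.0

x = jnp.ones(16_000_000, dtype=jnp.float32)
for _ in range(10):
    start = time.time()
    y = compute(x)
    y.block_until_ready()      # wait for the GPU before timing the copy
    host = np.asarray(y)       # device -> host transfer
    out = np.empty_like(host)
    comm.Allreduce(host, out, op=MPI.SUM)  # allreduce over the 32 Gbit network
    print('step took', time.time() - start)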
I ran it using MPICH:
mpiexec -f <hosts file> python <path to script>
The output looks like this:
num_params 1
compute : min_elapsed 0.000424 avg_elapsed 0.026451 max_elapsed 0.259608
device_to_host : min_elapsed 0.000070 avg_elapsed 0.000106 max_elapsed 0.000298
allreduce : min_elapsed 0.000209 avg_elapsed 0.002230 max_elapsed 0.018252
num_params 16000000
compute : min_elapsed 0.006838 avg_elapsed 0.023782 max_elapsed 0.155499
device_to_host : min_elapsed 0.123953 avg_elapsed 0.135843 max_elapsed 0.163817
allreduce : min_elapsed 0.505218 avg_elapsed 0.592024 max_elapsed 0.640469
So about 600 ms per allreduce for 16M float32s (64 MB), i.e. roughly 0.1 GB/s.
If I use nccl-tests built with MPI support (make MPI=1):
mpiexec -f <hosts file> ./nccl-tests/build/all_reduce_perf -b 1M -e 64M -f 2 -g 1 -c 0
The output looks like this:
#                                        out-of-place                    in-place
#       size      count  type  redop    time  algbw  busbw  error    time  algbw  busbw  error
#        (B)  (elements)                 (us) (GB/s) (GB/s)           (us) (GB/s) (GB/s)
     1048576      262144 float    sum  5328.1   0.20   0.39    N/A  3856.1   0.27   0.54    N/A
     2097152      524288 float    sum  6751.2   0.31   0.61    N/A  6132.7   0.34   0.67    N/A
     4194304     1048576 float    sum   11100   0.38   0.74    N/A   10899   0.38   0.76    N/A
     8388608     2097152 float    sum  9818.9   0.85   1.68    N/A  9351.1   0.90   1.77    N/A
    16777216     4194304 float    sum   17219   0.97   1.92    N/A   17121   0.98   1.93    N/A
    33554432     8388608 float    sum   35836   0.94   1.84    N/A   36609   0.92   1.80    N/A
    67108864    16777216 float    sum   73365   0.91   1.80    N/A   78911   0.85   1.67    N/A
Which works out to about 80 ms for 16M float32s, roughly 7x faster than the MPI allreduce above.
For my particular training setup, I am seeing ~600 ms spent in the allreduce out of ~800 ms total per training step, so improving this could reduce the runtime of my script substantially.
The two approaches that seem most promising to me would be:
- Use XLA’s existing NCCL support, or extend it, so that this call goes through XLA.
- Use pointers to GPU memory to call NCCL directly from Python (not sure whether this would hit odd interactions with XLA also using CUDA; see the sketch after this list).
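A hypothetical sketch of the second option, using cupy’s NCCL bindings plus mpi4py to bootstrap the communicator. The unsafe_buffer_pointer() call is the jaxlib method mentioned in the reply below; how these pieces fit together is an assumption on my part, not a tested recipe:

import jax.numpy as jnp
import cupy
from cupy.cuda import nccl
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Every rank needs the same NCCL unique id; broadcast it over MPI.
uid = comm.bcast(nccl.get_unique_id() if rank == 0 else None, root=0)
nccl_comm = nccl.NcclCommunicator(size, uid, rank)

x = jnp.ones(16_000_000, dtype=jnp.float32)
ptr = x.device_buffer.unsafe_buffer_pointer()  # raw device pointer into XLA's buffer

# In-place sum allreduce on the default stream. XLA is free to reuse or move
# this buffer, which is exactly why the pointer API is called "unsafe".
nccl_comm.allReduce(ptr, ptr, x.size, nccl.NCCL_FLOAT32, nccl.NCCL_SUM,
                    cupy.cuda.Stream.null.ptr)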
What do you guys think?
The unsafe_buffer_pointer() method should be available in jaxlib 0.1.22. Please experiment with it (although I wouldn’t consider it a final API).

I think we can consider this one fixed! JAX supports multi-worker GPU computation, and you can perform an all-reduce using pmap and psum.
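For anyone landing here later, a minimal sketch of that resolution on a single host (the same psum pattern extends across hosts once multi-worker JAX is set up):

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='i')
def allreduce(x):
    return jax.lax.psum(x, 'i')

n = jax.local_device_count()
x = jnp.ones((n, 16_000_000 // n), dtype=jnp.float32)
y = allreduce(x)  # every device now holds the sum over all devices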