slow pmap allreduce
Related to @joschu's question about direct access to device arrays, I was curious how fast a pmap allreduce would be as an alternative to trying to use NCCL directly on GPU pointers.
This script (https://gist.github.com/christopherhesse/192d78f0f082d66dfb26cac112c5cf99) takes 10,000 ms per loop on 8 V100s, which surprises me because nccl-tests' all_reduce_perf takes about 5 ms for what I believe is the same operation. Is there an error in my script? I tried using .block_until_ready() instead of np.array(), but that failed with an exception, so there's an additional copy to host memory. Even accounting for that copy, it seems like it should be faster.
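For reference, here is a minimal sketch of the kind of pmap allreduce being benchmarked (this is a hypothetical reconstruction, not the gist's exact code): each device holds one slice of the leading axis, and lax.psum over the mapped axis performs the all-reduce. The timing helper uses block_until_ready() to wait for device work without copying to the host, which works on recent JAX versions even though it raised an exception on pmap outputs in the version discussed here.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

n_dev = jax.local_device_count()

# pmap maps the function over the leading axis, one slice per device;
# lax.psum over the named axis "i" sums across all devices (an all-reduce).
allreduce = jax.pmap(lambda x: lax.psum(x, "i"), axis_name="i")

# One shard of 4 floats per device.
shards = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
out = allreduce(shards)  # every device now holds the cross-device sum


def time_allreduce(iters=100):
    # Warm-up call so compilation time is excluded from the measurement.
    allreduce(shards).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(iters):
        result = allreduce(shards)
    # Wait for the device computation to finish without a host copy.
    result.block_until_ready()
    return (time.perf_counter() - t0) / iters
```

Calling np.array(out) instead of block_until_ready() would also synchronize, but adds a device-to-host transfer to the measured time, which is the extra copy mentioned above.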
@jekbradbury commented on a similar issue here: https://github.com/google/jax/issues/606#issuecomment-485063016
I’m using jaxlib 0.1.21 and (I think) jax 1508405ce619e40f43c90f3c34d6af7d0a81ddd5.
Thanks! That should be enough information to investigate the difference in speed if it ends up impacting my application’s performance.
If we end up using NCCL directly, then I expect we won't have to copy much data to main memory, so this particular issue may not matter as much to me (especially if it is somehow pmap-specific).
Hrm not sure about the GCP thing. I can dig more into the config I’m using if that would be helpful, but the basics are:
- n1-standard-64 (64 vCPUs, 240 GB memory) in us-west-1b
- 300 GB SSD persistent disk
- 8 x NVIDIA Tesla V100
- "Deep Learning VM" image (maybe it's called tf-1-13-cu100-20190524)
- Miniconda (Anaconda) Python 3.7
- jax from GitHub master, jaxlib 0.1.21 from PyPI
cc @hawkinsp for someone who knows how computers are supposed to work. Any thoughts?