Issues (or typos?) when running JAX code with multiple GPUs
Hi there. Got two issues when running the JAX code with multiple GPUs:
- https://github.com/deepmind/ferminet/blob/b46077f1a4687b3ac03b583d9726f59ab4d914d9/ferminet/train.py#L293-L297
It hits a `too many values to unpack` error when `num_devices` is greater than 1. My understanding is that we should do `key, *subkeys = jax.random.split(key, num_devices + 1)` instead (note the extra asterisk), in which case the following explicit broadcast is no longer necessary for the single-GPU case. (A minimal sketch of this fix is included below.)
- https://github.com/deepmind/ferminet/blob/b46077f1a4687b3ac03b583d9726f59ab4d914d9/ferminet/train.py#L372-L373
`constants.pmap` gives a tuple of an array instead of just an array in this case when `num_devices` is greater than 1 (not sure why, probably just JAX's API). This causes logging to complain. It's easy to fix, though.
Let me know if this makes sense. Also, if you like, I can submit a tiny PR to fix them.
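Below is a minimal, self-contained sketch of the suggested fix for the first issue; the variable names (`num_devices`, `subkeys`) follow the issue text and are illustrative rather than copied from ferminet:

```python
import jax
import jax.numpy as jnp

num_devices = jax.local_device_count()
key = jax.random.PRNGKey(0)

# The original pattern, `key, subkeys = jax.random.split(key, num_devices + 1)`,
# raises "too many values to unpack" once num_devices > 1, because split
# returns num_devices + 1 keys rather than exactly two values.

# With a star-unpack, the first key is kept and the remaining keys go into a
# list with one subkey per device, for any num_devices >= 1.
key, *subkeys = jax.random.split(key, num_devices + 1)
assert len(subkeys) == num_devices

# Stack the subkeys so they can be passed as a per-device argument to pmap;
# on a single GPU the stacked array simply has a leading dimension of 1,
# so no extra explicit broadcast is needed.
subkeys = jnp.stack(subkeys)
```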
Issue Analytics
- Created: 3 years ago
- Comments: 6
That’s the correct thing to do. `pmean` is a collective reduce (i.e. like MPI_AllReduce rather than MPI_Reduce, if you’re familiar with MPI) and the result is sharded across all devices. The logging call requires the data on the host, which transfers the array back, resulting in an array of length `num_devices` whose elements are, as you say, identical due to the `pmean`. The same thing is done in several places during the main training loop.

Fixed in #17. Thanks for spotting this and sending the patch!
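To illustrate the behaviour described above, here is a small self-contained sketch (not the actual ferminet training loop) showing that a pmapped `lax.pmean` acts as an all-reduce: the value fetched back to the host is an array of length `num_devices` with identical entries, and a single replica such as `loss[0]` is enough for logging.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

@functools.partial(jax.pmap, axis_name='devices')
def mean_loss(per_device_loss):
  # Collective reduce: like MPI_AllReduce, every device receives the mean.
  return lax.pmean(per_device_loss, axis_name='devices')

num_devices = jax.local_device_count()
per_device_loss = jnp.arange(num_devices, dtype=jnp.float32)

loss = mean_loss(per_device_loss)   # result is sharded across devices
host_loss = jax.device_get(loss)    # host array of length num_devices
print(host_loss)                    # all entries identical, e.g. [0.5 0.5]

# For logging a scalar, reading one replica is sufficient.
print(float(loss[0]))
```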