Deadlock when using AMP with DistributedDataParallel
I've run into an issue when using AMP with DistributedDataParallel that leads to a deadlock at `backward()`. I am initializing AMP in the proper order as specified in the docs/examples (create models -> move models to GPU -> initialize optimizers -> initialize AMP -> wrap models in DDP), and I am using the `with amp.scale_loss(loss_d, optimizer_d) as scaled_loss:` syntax for loss scaling/backward. The model I'm working on is a GAN, so I'm initializing AMP with a list of models and a list of optimizers, and running `with amp.scale_loss()` once for D and again for G (each referencing its associated optimizer).
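To make the ordering concrete, here is a rough sketch of what the setup looks like (this is not my actual code; `netG`/`netD`, the optimizers, and the losses are placeholders):

```python
import torch
import torch.nn as nn
from apex import amp

# Placeholder models/optimizers standing in for my real G and D.
local_rank = 0  # in the real code this comes from the launch utility
torch.distributed.init_process_group(backend="nccl", init_method="env://")

netG = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 3)).cuda()
netD = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1)).cuda()
optimizer_g = torch.optim.Adam(netG.parameters(), lr=1e-4)
optimizer_d = torch.optim.Adam(netD.parameters(), lr=4e-4)

# AMP initialized with lists of models and optimizers (the hang happens at every opt_level).
[netG, netD], [optimizer_g, optimizer_d] = amp.initialize(
    [netG, netD], [optimizer_g, optimizer_d], opt_level="O1")

# Wrapped in PyTorch DDP *after* amp.initialize, per the Apex docs.
netG = nn.parallel.DistributedDataParallel(netG, device_ids=[local_rank])
netD = nn.parallel.DistributedDataParallel(netD, device_ids=[local_rank])

# D update, then G update, each with its own scale_loss context.
real = torch.randn(8, 3).cuda()
fake = netG(torch.randn(8, 128).cuda())
loss_d = netD(fake.detach()).mean() - netD(real).mean()  # placeholder loss
with amp.scale_loss(loss_d, optimizer_d) as scaled_loss:
    scaled_loss.backward()   # <-- this is where it deadlocks under DDP + AMP
optimizer_d.step()

loss_g = -netD(fake).mean()  # placeholder loss
with amp.scale_loss(loss_g, optimizer_g) as scaled_loss:
    scaled_loss.backward()
optimizer_g.step()
```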
The code runs as expected without AMP, and also when running with AMP on a single GPU (non-distributed). It is only when combining mixed precision with DDP that I run into this issue. As a note, I see the same issue at every mixed-precision opt level.
Some strange behavior shows up when using DDP and AMP together, beyond the deadlock itself. My train loop updates D first, so a forward pass through G is performed and then the D loss is defined as loss real + loss fake. The forward pass through D takes a very long time (upwards of a minute on a V100), I get a gradient overflow warning if I use dynamic scaling (it says it is reducing the loss scale), and then it hangs on `backward()`. Again, when running the same code without DDP everything works quite well (the training speed-up and memory reduction are awesome).
Another thing that may be relevant: I'm using a gradient penalty in my loss, so a call to `torch.autograd.grad()` is used to compute the GP. This doesn't cause any problems when either DDP or AMP is used on its own.
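In case it matters, the gradient penalty is computed along the lines of the usual WGAN-GP recipe, roughly like the sketch below (placeholder names again; this leaves out any AMP-specific handling and is only meant to show where `torch.autograd.grad()` comes in):

```python
import torch

def gradient_penalty(netD, real, fake, gp_weight=10.0):
    # Random interpolation between real and fake samples.
    alpha = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interp = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)

    d_interp = netD(interp)

    # create_graph=True so the penalty term itself can be backpropagated (double backward).
    grads = torch.autograd.grad(
        outputs=d_interp,
        inputs=interp,
        grad_outputs=torch.ones_like(d_interp),
        create_graph=True,
        retain_graph=True,
    )[0]

    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1) ** 2).mean()
```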
At the moment I cannot provide code samples for reproduction, so I understand this issue doesn't give you much to go on, but any thoughts or suggestions would be very helpful.
Edit: It is also worth mentioning that I am using PyTorch's DDP, not Apex DDP. I have not tested Apex DDP yet; I will try that shortly.
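Concretely, the difference is just which wrapper the models go through (sketch only; `local_rank` is whatever the launcher passes in):

```python
# What I am using now: PyTorch's native DDP.
from torch.nn.parallel import DistributedDataParallel as TorchDDP
netD = TorchDDP(netD, device_ids=[local_rank])

# What I have not tried yet: Apex's DDP wrapper.
from apex.parallel import DistributedDataParallel as ApexDDP
netD = ApexDDP(netD)
```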
Great! My suspicion is that Amp's casts were sending tensors to the default device for each process (which, without `set_device`, would have been device 0 for all processes). I think I will continue to recommend `set_device` to other people, since calling `set_device` is also PyTorch's official guidance for multiprocess training (https://pytorch.org/docs/stable/distributed.html#launch-utility).

I wanted to reiterate this because it is unrelated and may also be relevant: if Torch DDP is working, I think you should stick with that rather than trying Apex DDP. I don't think Apex DDP handles double-backward (e.g. gradient penalty) properly.
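For anyone who hits this later, here is a minimal sketch of the per-process setup being recommended, assuming the `torch.distributed.launch` style of launching where each process receives a `--local_rank` argument (the model and optimizer are placeholders):

```python
import argparse
import torch
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

# The key line: pin this process to its own GPU *before* building the model,
# so Amp's casts (and everything else) land on the right device instead of device 0.
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(10, 10).cuda()                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
```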
Okay, calling `set_device()` seems to have fixed this problem! I appreciate your fast reply and your help solving this.