DDP training on TPU is not working
Bug description
PyTorch-Lightning DDP training on TPU appears to be broken in “automatic_optimization” mode.
“automatic_optimization” is the default mode for a LightningModule trained by the Trainer: users only define the training forward pass in the module, and the Trainer automatically runs the backward pass and the optimizer step.
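For context, here is a minimal sketch of a module running in “automatic_optimization” mode (the toy model, shapes, and hyperparameters are placeholders, not taken from the issue):

```python
import torch
from torch import nn
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.layer(x), y)
        # No manual backward() or optimizer.step() here: the Trainer wraps
        # them into a closure and runs them on our behalf.
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```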
On TPU (PyTorch/XLA), DDP is usually achieved by calling xm.optimizer_step() for the optimization step, which adds a gradient all_reduce op before calling optimizer.step(). In PyTorch-Lightning, this is done in the TPUPrecisionPlugin class as:
```python
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
```
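Conceptually, xm.optimizer_step() behaves roughly like the following (a simplified sketch, not the actual torch_xla implementation):

```python
import torch_xla.core.xla_model as xm


def optimizer_step_sketch(optimizer, optimizer_args):
    # All-reduce the .grad buffers across replicas first ...
    xm.reduce_gradients(optimizer)
    # ... then run the optimizer step. With Lightning, optimizer_args contains
    # the closure, so forward and backward only happen inside this call.
    return optimizer.step(**optimizer_args)
```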
However, in “automatic_optimization” mode, PyTorch-Lightning actually puts the forward and backward passes into the closure callable given to optimizer.step(), so forward and backward happen inside optimizer.step(). What ends up happening in an iteration over a batch is:
- All_reduce gradients
- Forward pass
- Zero gradients
- Backward pass
- Optimizer step
The all_reduce is effectively a no-op, so the gradients are not synchronized between DDP processes before being applied to the models. What should happen instead is (a minimal sketch follows the list):
- Forward pass
- Zero gradients
- Backward pass
- All_reduce gradients
- Optimizer step
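A minimal sketch of this correct ordering, assuming a plain PyTorch/XLA training step (the model, data, and optimizer below are placeholders, not Lightning internals):

```python
import torch
from torch import nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(32, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 32, device=device)
y = torch.randint(0, 2, (8,), device=device)


def closure():
    optimizer.zero_grad()                             # zero gradients
    loss = nn.functional.cross_entropy(model(x), y)   # forward pass
    loss.backward()                                   # backward pass populates .grad
    xm.reduce_gradients(optimizer)                    # all-reduce gradients across replicas
    return loss


optimizer.step(closure=closure)  # the update now sees the synchronized gradients
```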
The gradient all_reduce op needs to be inserted into the closure. A possible fix could be:
```diff
diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..2dda36d3f 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -39,8 +46,9 @@ class TPUPrecisionPlugin(PrecisionPlugin):
     ) -> Any:
         import torch_xla.core.xla_model as xm
 
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:
```
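With this patch, the plugin calls optimizer.step(closure=closure) directly instead of going through xm.optimizer_step(): the gradient all-reduce that xm.optimizer_step() used to provide now happens inside the closure, right after backward, so the outer wrapper is no longer needed.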
A comparison with and without the patch was done on the modified MNIST TPU tutorial with BATCH_SIZE = 128: TensorBoard log. The training loss curve there shows that the model converges faster with the fixed code, as the gradients are correctly reduced and the model replicas stay in sync across processes.
IR graphs can also be dumped before and after the fix. Before the fix, no xla::cross_replica_sum() op can be found in an iteration, while after the fix it correctly appears after the backward ops.
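One way to check for that op (an assumption based on torch_xla's standard debugging facilities, not something the issue prescribes) is to dump the traced graphs to a file and search them:

```python
import os

# These must be set before torch_xla initializes / traces any graphs.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_graphs.txt"
os.environ["XLA_SAVE_TENSORS_FMT"] = "text"  # human-readable IR

# ... run one training iteration, then search the dump, e.g.:
#   grep cross_replica_sum /tmp/xla_graphs.txt
```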
How to reproduce the bug
Error messages and logs
Environment
- Lightning Component: PyTorch-Lightning
- PyTorch Lightning Version: master
- PyTorch Version: 1.14
- OS: Linux
- CUDA/cuDNN version: None
- TPU models and configuration: Google Cloud TPU V3-8
- How you installed Lightning: source
- Running environment of LightningApp: cloud
Top GitHub Comments
AFAIK Lite does not suffer from this issue because it doesn’t need to run the closure separately. This fix is only relevant for the src/pytorch_lightning/plugins/precision/tpu.py file. You can get started on a PR and we’ll continue reviewing there.

I notice there are other places that call xm.optimizer_step(), such as in lightning, lightning_lite, and pytorch_lightning. I don’t know enough about how to incorporate this change into all the places that have a similar issue, so I hope someone who knows more about the package structure can help.