
DDP training on TPU is not working

See original GitHub issue

Bug description

PyTorch-Lightning DDP training on TPU appears to be broken in “automatic_optimization” mode.

“automatic_optimization” is the default mode for a LightningModule trained by the Trainer: users only need to define the training forward pass in the module, and the Trainer automatically performs the backward pass and optimizer step.
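For context, a minimal sketch of such a module (illustrative only, not taken from the issue):

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class MinimalModule(pl.LightningModule):
    """Illustrative module: only the training forward pass is defined."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x), y)
        return loss  # the Trainer runs zero_grad/backward/optimizer.step itself

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)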

On TPU (PyTorch/XLA), DDP is usually achieved by calling xm.optimizer_step() for the optimization step, which adds a gradient all_reduce op before calling optimizer.step(). In PyTorch-Lightning, this is done in the TPUPrecisionPlugin class as:

closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
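Roughly, xm.optimizer_step() behaves like the following (a simplified sketch, not the actual torch_xla source; barrier/mark_step handling is omitted):

import torch_xla.core.xla_model as xm

def optimizer_step_sketch(optimizer, optimizer_args={}):
    xm.reduce_gradients(optimizer)           # all_reduce the current .grad tensors
    return optimizer.step(**optimizer_args)  # the closure only runs in here

So the gradient all_reduce runs before optimizer.step(), and therefore before any closure passed through optimizer_args is executed.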

However, in “automatic_optimization” mode, PyTorch-Lightning actually puts the forward and backward passes into the closure callable given to optimizer.step(), so forward and backward happen inside optimizer.step(). What ends up happening in an iteration over a batch is:

  1. All_reduce gradients
  2. Forward pass
  3. Zero gradients
  4. Backward pass
  5. Optimizer step

The all_reduce is effectively a no-op, so the gradients are not synchronized across DDP processes before being applied to the model. What should happen instead is:

  1. Forward pass
  2. Zero gradients
  3. Backward pass
  4. All_reduce gradients
  5. Optimizer step

The gradient all_reduce op needs to be inserted into the closure. A possible fix could be:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..2dda36d3f 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
@@ -39,8 +46,9 @@ class TPUPrecisionPlugin(PrecisionPlugin):
    ) -> Any:
        import torch_xla.core.xla_model as xm
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if model.automatic_optimization and skipped_backward:
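Note that the patch also replaces xm.optimizer_step() with a plain optimizer.step() call: with the all_reduce now running inside the wrapped closure, keeping xm.optimizer_step() would add a second, redundant all_reduce on stale gradients at the start of every step.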

A comparison with and without the patch was run on the modified MNIST TPU tutorial with BATCH_SIZE = 128: Tensorboard log. The training loss curve shows that the model converges faster with the fix, as the gradients are correctly reduced and the model replicas stay in sync across processes.

IR graphs can also be dumped before and after the fix. Before the fix, no xla::cross_replica_sum() op can be found in an iteration; after the fix, it correctly appears after the backward ops.
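One way to obtain such dumps (assuming the standard torch_xla debugging environment variables; the file path is arbitrary):

import os

# Must be set before torch_xla is imported/initialized.
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/xla_ir_dump.txt"  # where to write the IR
os.environ["XLA_SAVE_TENSORS_FMT"] = "text"                   # plain-text IR format

# ... run a couple of training iterations, then search the dump, e.g.:
#   grep cross_replica_sum /tmp/xla_ir_dump.txt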

How to reproduce the bug

MNIST TPU tutorial
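As a smaller stand-in for the full tutorial, the MinimalModule sketched earlier can be driven on a v3-8 along these lines (random data, illustrative only; BATCH_SIZE = 128 matches the comparison above):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

# Random stand-in data with the same feature size as MinimalModule's input.
loader = DataLoader(
    TensorDataset(torch.randn(1024, 32), torch.randint(0, 2, (1024,))),
    batch_size=128,
)

trainer = pl.Trainer(accelerator="tpu", devices=8, max_epochs=1)
trainer.fit(MinimalModule(), loader)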

Error messages and logs


No response

Environment


  • Lightning Component: PyTorch-Lightning
  • PyTorch Lightning Version: master
  • PyTorch Version: 1.14
  • OS: Linux
  • CUDA/cuDNN version: None
  • TPU models and configuration: Google Cloud TPU V3-8
  • How you installed Lightning: source
  • Running environment of LightningApp: cloud

More info

No response

Issue Analytics

  • State: open
  • Created: 10 months ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

carmocca commented, Dec 12, 2022 (1 reaction)

AFAIK Lite does not suffer from this issue because it doesn’t need to run the closure separately. This fix is only relevant for the src/pytorch_lightning/plugins/precision/tpu.py file. You can get started on a PR and we’ll continue reviewing there.

Liyang90 commented, Dec 1, 2022

> Thank you so much! Do you want to open a PR implementing your solution?

I notice there are other places that call xm.optimizer_step(), such as in lightning, lightning_lite, and pytorch_lightning. I don’t know enough to incorporate this change into all the places that have a similar issue, so I hope someone who knows more about the package structure can help.
