
TPU performance issues and potential fixes

See original GitHub issue

Bug description

Issue 1

Usually, mark_step() happens at the beginning of the next iteration, when the MpDeviceLoader-wrapped dataloader is iterated. However, PyTorch Lightning may run multiple callbacks at the end of a batch iteration, such as progress-bar refreshing, logging, metrics tracking, and running-loss updating. Users can also add their own end-of-batch callbacks. These callbacks can access the values of lazy tensors and trigger early evaluations (extra compilations and computations). As an easy fix, we can materialize all lazy tensors with an xm.mark_step() call right after the optimizer step, just before the callbacks access the tensor values.
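
For illustration, here is a minimal sketch in plain torch_xla (not Lightning code; the model, data, and loss are placeholders) showing where the early evaluation comes from and where the mark_step() lands with the fix:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(128, 2).to(device)      # placeholder model
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):                              # stand-in for an MpDeviceLoader-wrapped dataloader
    inputs = torch.randn(8, 128, device=device)
    targets = torch.randint(0, 2, (8,), device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)  # barrier=True runs mark_step() here,
                                                # materializing `loss` and the parameter updates
    print(loss.item())                          # an end-of-batch callback reading the value is now
                                                # a plain transfer; without the barrier, this .item()
                                                # would trigger an early, separate graph evaluation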

On top of the original code:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..7624a7626 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -40,7 +40,7 @@ class TPUPrecisionPlugin(PrecisionPlugin):
        import torch_xla.core.xla_model as xm
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = xm.optimizer_step(optimizer, barrier=True, optimizer_args={"closure": closure, **kwargs})
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if model.automatic_optimization and skipped_backward:

Here, a barrier=True argument is added to the xm.optimizer_step() call, which triggers a mark_step() after the optimizer step.

On top of the changes proposed in https://github.com/Lightning-AI/lightning/issues/15878:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..3f1b59059 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
@@ -39,8 +46,10 @@ class TPUPrecisionPlugin(PrecisionPlugin):
    ) -> Any:
        import torch_xla.core.xla_model as xm
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if model.automatic_optimization and skipped_backward:

Here, an explicit xm.mark_step() call is added after the optimizer step.
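
For context, a rough sketch in plain torch_xla of what this decomposition amounts to (a hedged illustration, not the Lightning implementation):

import torch_xla.core.xla_model as xm

def decomposed_optimizer_step(optimizer, closure):
    loss = closure()                 # forward + backward, as wrapped by Lightning's closure
    xm.reduce_gradients(optimizer)   # all-reduce the gradients across TPU replicas
    optimizer.step()                 # plain optimizer step on the reduced gradients
    xm.mark_step()                   # materialize the lazy graph before end-of-batch callbacks run
    return loss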

Issue 2

diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index 59856f12e..a9275a486 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -72,6 +72,8 @@ class TensorRunningAccum:
    def append(self, x: torch.Tensor) -> None:
        """Add an element to the accumulator."""
+        if x.device.type == "xla":
+            x = x.cpu()
        if self.memory is None:
            # tradeoff memory for speed by keeping the memory on device
            self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)

This patch moves running-loss tracking to the CPU when training on TPU.

The running loss is tracked in a fixed-length tensor, memory (window size 20 by default). In every iteration, the new loss tensor is inserted at an incrementing index: self.memory[self.current_idx] = x. If the running loss is tracked on the TPU as a lazy tensor, this in-place update becomes an xla::update_slice() op with a different base_indices argument in each iteration, and the inserted loss tensor (x) is a lazy tensor with a huge graph, essentially the graph of the whole forward pass leading to that loss.

During training iterations, memory is somehow not considered a live tensor that needs to be synced and materialized by mark_step(), so it is never materialized. Then, when its value is finally accessed during teardown(), all the losses ever inserted into it, together with their graphs, are replayed.

(Update: this bug in PT/XLA has recently been fixed, so the memory tensor can now be included in the graph materialized by mark_step(). However, the patch is still necessary, because an xla::update_slice() op with a different base_indices argument in each iteration leads to recompilation even though the rest of the graph, the actual model-training work, is identical. The patch is also necessary for users on torch_xla < 1.13.)
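
The recompilation is easy to observe with torch_xla's debug metrics (a small hedged experiment, not part of the patch): updating a fixed-size window at a changing Python index on the XLA device bumps the CompileTime metric on every iteration, because the index is baked into each graph as a constant.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
window = torch.zeros(20, device=device)

for i in range(5):
    loss = torch.randn((), device=device)  # stand-in for the batch loss
    window[i % 20] = loss                  # xla::update_slice with a new base index each iteration
    xm.mark_step()
    data = met.metric_data("CompileTime")  # (count, total, samples) or None
    print("compiles so far:", data[0] if data else 0)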

Given the simple purpose of the running-loss tensor, we can trade more server-to-host communication for much simpler compilations and computations by sending the loss tensor to the CPU and tracking the running loss there, as sketched below. With the patch for Issue 1, the loss tensor is already materialized at this point, so the move does not trigger an early evaluation; it is simply a server-to-host transfer.
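
A standalone sketch of the intended behaviour (illustrative only, mirroring rather than reproducing Lightning's TensorRunningAccum): the loss is pulled to the host before it enters the window, so the window update never becomes part of the device graph.

from typing import Optional

import torch

class CPURunningLoss:
    """Fixed-size running-loss window kept on the CPU."""

    def __init__(self, window_length: int = 20) -> None:
        self.window_length = window_length
        self.memory: Optional[torch.Tensor] = None
        self.current_idx = 0

    def append(self, x: torch.Tensor) -> None:
        if x.device.type == "xla":
            x = x.cpu()              # key idea from the patch: one server-to-host transfer per step
        x = x.detach()
        if self.memory is None:
            self.memory = torch.zeros(self.window_length, *x.shape, dtype=x.dtype)
        self.memory[self.current_idx % self.window_length] = x
        self.current_idx += 1

    def mean(self) -> torch.Tensor:
        # Assumes at least one value has been appended.
        filled = min(self.current_idx, self.window_length)
        return self.memory[:filled].mean()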

Issue 3

PyTorch Lightning moves the logged metrics from TPU to CPU according to this line of code. But a few lines below they are moved back to the TPU (unintentionally, I assume), because self.device still points to the XLA device. This leads to additional compilation and server-to-host transfers whenever the metrics are accessed. The patch below therefore keeps the _ResultCollection object on the CPU even when the training module is on the TPU, so the logged metrics are never moved back to the TPU.

Issue https://github.com/Lightning-AI/lightning/issues/15743 might be related to this.

diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index d4c74f306..383125a41 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -398,7 +398,11 @@ class _ResultCollection(dict):
    def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
        super().__init__()
        self.training = training
-        self.device: Optional[Union[str, torch.device]] = device
+        if device:
+            device = torch.device(device)
+            if device.type == "xla":
+                device = torch.device("cpu")
+        self.device: Optional[torch.device] = device
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.dataloader_idx: Optional[int] = None
@@ -635,10 +639,14 @@ class _ResultCollection(dict):
    def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
        """Move all data to the given device."""
-        self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
-
        if "device" in kwargs:
-            self.device = kwargs["device"]
+            device = torch.device(kwargs["device"])
+            if device.type == "xla":
+                kwargs["device"] = "cpu"
+                device = torch.device(kwargs["device"])
+            self.device = device
+
+        self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
        return self
    def cpu(self) -> "_ResultCollection":
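
The same remapping idea as a standalone hedged sketch outside of _ResultCollection (the helper names are illustrative, not Lightning API): any request to place logged metrics on an "xla" device is redirected to the CPU, so the metrics never bounce back to the TPU.

from typing import Dict, Optional, Union

import torch

def normalize_metrics_device(device: Optional[Union[str, torch.device]]) -> Optional[torch.device]:
    """Map an XLA device to the CPU; pass everything else through."""
    if device is None:
        return None
    device = torch.device(device)
    return torch.device("cpu") if device.type == "xla" else device

def move_metrics(metrics: Dict[str, torch.Tensor],
                 device: Optional[Union[str, torch.device]]) -> Dict[str, torch.Tensor]:
    """Move a flat dict of metric tensors to the normalized device."""
    target = normalize_metrics_device(device)
    if target is None:
        return metrics
    return {name: t.to(target) for name, t in metrics.items()}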

How to reproduce the bug

No response

Error messages and logs


# Error messages and logs here please

Environment


#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

Issue Analytics

  • State: open
  • Created: 10 months ago
  • Reactions: 2
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

2 reactions
JackCaoG commented, Dec 7, 2022

Thanks @awaelchli, can you also add @Liyang90 to the tpu label notifier?

2 reactions
JackCaoG commented, Dec 6, 2022

Thanks @Liyang90! This is great. Can you cc Steven and me in this kind of issue? I didn’t find a good way to subscribe to labels 😄
