TPU performance issues and potential fixes
Bug description
Issue 1
Usually mark_step() happens at the beginning of the next iteration, when the MpDeviceLoader-wrapped dataloader is iterated. However, PyTorch-Lightning may run multiple callbacks at the end of a batch iteration, such as refreshing the progress bar, logging, tracking metrics, and updating the running loss. Users can also add their own end-of-batch callbacks. These callbacks may access the values of lazy tensors and trigger early evaluations (extra compilations and computations). As an easy fix, we can materialize all lazy tensors right after the optimizer step with an xm.mark_step() call, before any callback accesses the tensor values.
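To make the cost concrete, here is a toy model of lazy evaluation in plain Python. ToyLazyTensor and toy_mark_step are hypothetical stand-ins, not torch_xla APIs: a tensor records pending work until its value is read, reading an unmaterialized tensor counts as one extra graph execution, and, as assumed from the description above, mark_step() materializes every live pending tensor in a single execution.

```python
from typing import Callable, List, Optional


class ToyLazyTensor:
    """Hypothetical stand-in for a lazy XLA tensor (not the torch_xla API)."""

    executions = 0                    # graph executions triggered so far
    live: List["ToyLazyTensor"] = []  # tensors with pending, unmaterialized work

    def __init__(self, compute: Callable[[], float]) -> None:
        self._compute = compute
        self._value: Optional[float] = None
        ToyLazyTensor.live.append(self)

    def item(self) -> float:
        # Reading a pending tensor forces an early, separate execution.
        if self._value is None:
            ToyLazyTensor.executions += 1
            self._materialize()
        return self._value

    def _materialize(self) -> None:
        self._value = self._compute()
        if self in ToyLazyTensor.live:
            ToyLazyTensor.live.remove(self)


def toy_mark_step() -> None:
    """Materialize every live tensor in one execution (toy model of xm.mark_step())."""
    if ToyLazyTensor.live:
        ToyLazyTensor.executions += 1
        for t in list(ToyLazyTensor.live):
            t._materialize()


def run_iteration(barrier: bool) -> int:
    ToyLazyTensor.executions = 0
    ToyLazyTensor.live = []
    loss = ToyLazyTensor(lambda: 0.5)        # loss from forward/backward
    grad_norm = ToyLazyTensor(lambda: 1.25)  # another value a callback might log
    ToyLazyTensor(lambda: 0.0)               # updated parameters, never read by callbacks
    if barrier:
        toy_mark_step()                      # materialize right after the optimizer step
    for value in (loss, grad_norm):          # end-of-batch callbacks read values
        value.item()
    toy_mark_step()                          # next iteration's mark_step from the loader
    return ToyLazyTensor.executions


print(run_iteration(barrier=False))  # 3: two early executions plus the loader's mark_step
print(run_iteration(barrier=True))   # 1: one execution, callbacks read cached values
```

The toy ignores that early executions in a real lazy-tensor runtime also recompute overlapping parts of the graph, so it understates the cost of the no-barrier case.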
On top of the original code:
```diff
diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..7624a7626 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -40,7 +40,7 @@ class TPUPrecisionPlugin(PrecisionPlugin):
         import torch_xla.core.xla_model as xm
 
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = xm.optimizer_step(optimizer, barrier=True, optimizer_args={"closure": closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:
```
Here a barrier=True argument is added to the xm.optimizer_step() call, which triggers a mark_step() right after the optimizer step.
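For reference, the sketch below paraphrases the control flow of xm.optimizer_step() as I understand it from torch_xla 1.13 (check it against the version you run): it reduces gradients across replicas, calls optimizer.step(**optimizer_args), and calls mark_step() only when barrier=True. The reduce and mark_step functions are injected as parameters so the ordering can be checked without a TPU; FakeOptimizer and optimizer_step_sketch are hypothetical names.

```python
from typing import Any, Callable, Dict, List, Optional


class FakeOptimizer:
    """Hypothetical optimizer that only records what happens to it."""

    def __init__(self, events: List[str]) -> None:
        self.events = events

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        # Like torch.optim optimizers, run the closure (forward + backward) first.
        loss = closure() if closure is not None else None
        self.events.append("update")
        return loss


def optimizer_step_sketch(
    optimizer: FakeOptimizer,
    reduce_gradients: Callable[[FakeOptimizer], None],
    mark_step: Callable[[], None],
    barrier: bool = False,
    optimizer_args: Optional[Dict[str, Any]] = None,
) -> Any:
    """Paraphrase of xm.optimizer_step(): reduce, step, then optionally mark_step."""
    reduce_gradients(optimizer)
    result = optimizer.step(**(optimizer_args or {}))
    if barrier:
        mark_step()  # barrier=True materializes the lazy graph right here
    return result


events: List[str] = []
optimizer_step_sketch(
    FakeOptimizer(events),
    reduce_gradients=lambda o: events.append("reduce_gradients"),
    mark_step=lambda: events.append("mark_step"),
    barrier=True,
    optimizer_args={"closure": lambda: events.append("closure") or 0.5},
)
print(events)  # ['reduce_gradients', 'closure', 'update', 'mark_step']
```

Note that when a closure is passed through optimizer_args, the recorded order has the gradient reduction running before the closure's backward pass.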
Alternatively, on top of the changes proposed in https://github.com/Lightning-AI/lightning/issues/15878:
```diff
diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..3f1b59059 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
         super().__init__(*args, **kwargs)
 
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
     def optimizer_step(  # type: ignore[override]
         self,
         optimizer: Optimizable,
@@ -39,8 +46,10 @@ class TPUPrecisionPlugin(PrecisionPlugin):
     ) -> Any:
         import torch_xla.core.xla_model as xm
 
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if model.automatic_optimization and skipped_backward:
```
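Here the gradient reduction (xm.reduce_gradients()) moves into the closure, right after the closure runs, optimizer.step() is called directly, and an explicit xm.mark_step() call is added after the optimizer step.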
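The ordering this second patch produces can be traced the same way. The following is a plain-Python mirror of the diff with hypothetical names (FakeOptimizer, tpu_wrap_closure, patched_optimizer_step); it leaves out Lightning's own _wrap_closure and only records the order of events, with xm.reduce_gradients and xm.mark_step replaced by injected callables.

```python
from functools import partial
from typing import Any, Callable, List, Optional


class FakeOptimizer:
    """Hypothetical optimizer that records events instead of updating weights."""

    def __init__(self, events: List[str]) -> None:
        self.events = events

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        loss = closure() if closure is not None else None  # forward + backward first
        self.events.append("update")
        return loss


def tpu_wrap_closure(optimizer: FakeOptimizer, closure: Callable[[], Any],
                     reduce_gradients: Callable[[FakeOptimizer], None]) -> Any:
    """Mirrors _tpu_wrap_closure: run the closure, then reduce the fresh gradients."""
    result = closure()
    reduce_gradients(optimizer)
    return result


def patched_optimizer_step(optimizer: FakeOptimizer, closure: Callable[[], Any],
                           reduce_gradients: Callable[[FakeOptimizer], None],
                           mark_step: Callable[[], None]) -> Any:
    """Mirrors the patched optimizer_step, minus Lightning's own closure wrapping."""
    wrapped = partial(tpu_wrap_closure, optimizer, closure, reduce_gradients)
    result = optimizer.step(closure=wrapped)
    mark_step()  # materialize all lazy tensors before end-of-batch callbacks run
    return result


events: List[str] = []
loss = patched_optimizer_step(
    FakeOptimizer(events),
    closure=lambda: events.append("forward+backward") or 0.5,
    reduce_gradients=lambda o: events.append("reduce_gradients"),
    mark_step=lambda: events.append("mark_step"),
)
print(events)  # ['forward+backward', 'reduce_gradients', 'update', 'mark_step']
```

In this trace the gradients are reduced only after the backward pass has produced them, and the explicit mark_step comes after the parameter update.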
Issue 2
```diff
diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index 59856f12e..a9275a486 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -72,6 +72,8 @@ class TensorRunningAccum:
 
     def append(self, x: torch.Tensor) -> None:
         """Add an element to the accumulator."""
+        if x.device.type == "xla":
+            x = x.cpu()
         if self.memory is None:
             # tradeoff memory for speed by keeping the memory on device
             self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)
```
This patch moves the running loss tracking to the CPU when training on TPU.
The running loss is tracked in a fixed-length tensor, memory (its size is 20 by default). In every iteration, the new loss tensor is written at an incrementing index in memory: self.memory[self.current_idx] = x. If the running loss is tracked on the TPU as a lazy tensor, this in-place update becomes an xla::update_slice() op with a different base_indices argument in each iteration, and the inserted loss tensor (x) is a lazy tensor with a huge graph: essentially the graph of the whole forward pass leading to this loss tensor.
During training iterations, memory is somehow not considered a live tensor that needs to be synced and materialized by mark_step(), so it is never materialized. Then, when the value of memory is finally accessed during teardown(), all the losses ever inserted into it, together with their graphs, are replayed.
(Update: this bug in PT/XLA has recently been fixed, so the memory tensor is now included in the graph that is materialized when mark_step() is called. However, the patch is still necessary: an xla::update_slice() op with a different base_indices argument in each iteration leads to a recompilation, even though the rest of the graph, which does the actual training work, is identical. The patch is also necessary for users on torch_xla < 1.13.)
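A toy compilation cache shows why a changing base_indices alone forces recompilation. The sketch assumes, as the update above states, that the slice index ends up baked into the graph, so it becomes part of the cache key; count_compilations and its key tuples are hypothetical and do not reflect how PT/XLA actually hashes graphs.

```python
from typing import Hashable, Set


def count_compilations(num_steps: int, window_length: int, on_device: bool) -> int:
    """Count distinct graphs a toy compile cache would see over num_steps iterations."""
    cache: Set[Hashable] = set()
    for step in range(num_steps):
        current_idx = step % window_length
        if on_device:
            # training graph + update_slice whose index is part of the graph
            key: Hashable = ("train_step", "update_slice", current_idx)
        else:
            # loss copied to the host; the device graph no longer contains the index
            key = ("train_step",)
        cache.add(key)  # every new key would mean one more compilation
    return len(cache)


print(count_compilations(100, window_length=20, on_device=True))   # 20
print(count_compilations(100, window_length=20, on_device=False))  # 1
```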
Given the simple purpose of the running loss tensor, we can trade more server-to-host communication for much simpler compilations and computations by sending the loss tensor to the CPU and tracking the running loss there. With the patch for Issue 1, the loss tensor is already materialized at this point, so the transfer does not trigger an early evaluation; it is simply a server-to-host copy.
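A host-side running mean needs very little: a fixed-size ring buffer indexed modulo the window length. HostRunningAccum below is a hypothetical, simplified stand-in for Lightning's TensorRunningAccum that stores Python floats (for example the result of loss.item()) instead of tensors; it is a sketch of the idea, not the patched class.

```python
from typing import List, Optional


class HostRunningAccum:
    """Hypothetical CPU-only running mean over the last window_length values."""

    def __init__(self, window_length: int = 20) -> None:
        self.window_length = window_length
        self.memory: List[float] = [0.0] * window_length
        self.current_idx = 0
        self.rotated = False  # True once the buffer has wrapped around

    def append(self, x: float) -> None:
        # Plain host-memory write: nothing here is traced into a device graph.
        self.memory[self.current_idx] = x
        self.current_idx = (self.current_idx + 1) % self.window_length
        if self.current_idx == 0:
            self.rotated = True

    def mean(self) -> Optional[float]:
        n = self.window_length if self.rotated else self.current_idx
        if n == 0:
            return None
        return sum(self.memory[:n]) / n


acc = HostRunningAccum(window_length=3)
for loss_value in [4.0, 2.0, 6.0, 8.0]:  # e.g. loss.item() after mark_step()
    acc.append(loss_value)
print(acc.mean())  # mean of the last three values: (8.0 + 2.0 + 6.0) / 3
```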
Issue 3
PyTorch-Lightning moves the logged metrics from TPU to CPU according to this line of code. But they are then moved back to TPU (unintentionally, I assume) a few lines below, because self.device still points to the XLA device. This leads to additional compilations and server-to-host transfers whenever the metrics are accessed. The patch below therefore keeps the _ResultCollection object on the CPU even when the training module is on TPU, so that the logged metrics are not moved back to TPU.
Issue https://github.com/Lightning-AI/lightning/issues/15743 might be related to this.
```diff
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index d4c74f306..383125a41 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -398,7 +398,11 @@ class _ResultCollection(dict):
     def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
         super().__init__()
         self.training = training
-        self.device: Optional[Union[str, torch.device]] = device
+        if device:
+            device = torch.device(device)
+            if device.type == "xla":
+                device = torch.device("cpu")
+        self.device: Optional[torch.device] = device
         self.batch: Optional[Any] = None
         self.batch_size: Optional[int] = None
         self.dataloader_idx: Optional[int] = None
@@ -635,10 +639,14 @@ class _ResultCollection(dict):
 
     def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
         """Move all data to the given device."""
-        self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
-
         if "device" in kwargs:
-            self.device = kwargs["device"]
+            device = torch.device(kwargs["device"])
+            if device.type == "xla":
+                kwargs["device"] = "cpu"
+                device = torch.device(kwargs["device"])
+            self.device = device
+
+        self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
         return self
 
     def cpu(self) -> "_ResultCollection":
```
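The device handling in this patch boils down to one rule: if the requested device is an XLA device, keep the results on the CPU instead. Below is a torch-free sketch of that rule; resolve_result_device is a hypothetical helper, and device strings such as "xla:1" are assumed to follow the "type:index" form that torch.device parses.

```python
from typing import Optional


def resolve_result_device(device: Optional[str]) -> Optional[str]:
    """Hypothetical helper: map any XLA device to "cpu", leave others unchanged."""
    if not device:
        return device  # None means "no device chosen yet", as in the patched __init__
    device_type = device.split(":", 1)[0]
    return "cpu" if device_type == "xla" else device


for requested in ["xla:1", "xla", "cuda:0", "cpu", None]:
    print(requested, "->", resolve_result_device(requested))
```

In the patch, the same rule is applied both in __init__ and in to(), and to() now rewrites kwargs["device"] before calling apply_to_collection, which is why the self.update(...) call moves below the device check: the data is moved with the already-corrected device.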
Top GitHub Comments
Thanks @awaelchli , can you also add @Liyang90 to the tpu label notifier?
Thanks @Liyang90! This is great. Can you cc Steven and me in this kind of issue? I didn’t find a good way to subscribe to labels 😄