TQDMProgressBar refresh forces TPU to recompile compute graph
See original GitHub issue

Bug description
Every TQDMProgressBar refresh forces the TPU to recompile the compute graph, which slows down execution.
How to reproduce the bug
import re
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torch_xla.debug.metrics as met
from pytorch_lightning.callbacks import TQDMProgressBar
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import Callback
import torch_xla.core.xla_model as xm
class dummyDataset(Dataset):
    """Synthetic map-style dataset that yields random ``(features, target)`` pairs.

    Every lookup draws fresh random tensors: a feature vector of 512 values
    and a target of 1 value. The dataset reports a fixed length of 10_000.
    """

    def __getitem__(self, index):
        """Return a random ``(torch.rand(512), torch.rand(1))`` pair.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        # Without an IndexError the sequence-iteration protocol never stops,
        # so ``for item in dataset`` / ``list(dataset)`` would loop forever.
        if not -size <= index < size:
            raise IndexError(f'index {index} is out of range for dataset of size {size}')
        return torch.rand(512), torch.rand(1)

    def __len__(self):
        """Return the fixed number of samples."""
        return 10_000
class dummyModel(pl.LightningModule):
    """Minimal LightningModule: two stacked linear layers (512 -> 512 -> 1).

    It exists only to give the trainer something to step through while the
    XLA compile counter is observed.
    """

    def __init__(self):
        # The parent initializer must run before any submodule is assigned,
        # otherwise setting ``self.net`` fails on an uninitialised nn.Module.
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(512, 512),
            nn.Linear(512, 1),
        )

    def training_step(self, batch):
        """Run the network on one ``(x, y)`` batch and return ``{'loss': ...}``."""
        x, y = batch
        logits = self.net(x)
        # NOTE(review): with a single output unit this cross-entropy is
        # presumably constant (the logs show loss=0) - fine for a repro,
        # but confirm before reusing this as a real objective.
        return {'loss': F.cross_entropy(logits, y)}

    def configure_optimizers(self):
        """Return an AdamW optimizer over all model parameters."""
        return torch.optim.AdamW(self.parameters())
class TPUMetricCallback(Callback):
    """Print the ``XrtCompile`` sample count after every training batch."""

    # Raw string: ``\s`` and ``\d`` are regex classes, not string escapes.
    # Compiled once here rather than on every batch.
    _XRT_COMPILE_PATTERN = re.compile(r'Metric: XrtCompile\s+TotalSamples: (\d+)')

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Read the XLA metrics report and print the compile count for this batch.

        Only the master ordinal prints; every other process returns at once.
        """
        if not xm.is_master_ordinal():
            return
        report = met.metrics_report()
        found = self._XRT_COMPILE_PATTERN.search(report)
        # The report may not contain the metric; fall back to a placeholder
        # instead of raising AttributeError on a ``None`` match.
        xrt_compile_count = found.group(1) if found else 'n/a'
        print(f'XrtCompile: {xrt_compile_count}, batch_idx = {batch_idx}')
def main():
    """Build the dummy data, model and callbacks, then run the training loop."""
    # refresh_rate is the knob under investigation: the logs compare 1 and 5.
    refresh_rate = 1
    ds = dummyDataset()
    dl = DataLoader(ds, batch_size=16)
    model = dummyModel()
    metrics_callback = TPUMetricCallback()
    tqdm_callback = TQDMProgressBar(refresh_rate=refresh_rate)
    # NOTE(review): no accelerator/devices arguments are visible in this
    # snippet; the original report ran on a TPU - confirm the Trainer
    # configuration before using this to reproduce the issue.
    trainer = pl.Trainer(
        max_epochs=10,
        callbacks=[tqdm_callback, metrics_callback],
    )
    # Without this call the dataloader and model above are never used and
    # no training (and therefore no compile metric output) happens.
    trainer.fit(model, dl)
# Run the reproduction only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Error messages and logs
Refresh_rate = 1
XrtCompile: 3, batch_idx = 0 Epoch 0: 0%|▎ | 2/625 [00:02<12:42, 1.22s/it, loss=0, v_num=0] XrtCompile: 4, batch_idx = 1 Epoch 0: 0%|▍ | 3/625 [00:03<12:41, 1.22s/it, loss=0, v_num=0] XrtCompile: 5, batch_idx = 2 Epoch 0: 1%|▋ | 4/625 [00:04<12:55, 1.25s/it, loss=0, v_num=0] XrtCompile: 6, batch_idx = 3 Epoch 0: 1%|▊ | 5/625 [00:06<13:22, 1.29s/it, loss=0, v_num=0] XrtCompile: 7, batch_idx = 4 Epoch 0: 1%|▉ | 6/625 [00:08<13:50, 1.34s/it, loss=0, v_num=0] XrtCompile: 8, batch_idx = 5 Epoch 0: 1%|█ | 7/625 [00:09<14:26, 1.40s/it, loss=0, v_num=0] XrtCompile: 9, batch_idx = 6 Epoch 0: 1%|█▎ | 8/625 [00:11<15:02, 1.46s/it, loss=0, v_num=0] XrtCompile: 10, batch_idx = 7 Epoch 0: 1%|█▍
Refresh_rate = 5
Epoch 0: 0%| | 0/625 [00:00<?, ?it/s] XrtCompile: 2, batch_idx = 0 XrtCompile: 2, batch_idx = 1 XrtCompile: 2, batch_idx = 2 XrtCompile: 2, batch_idx = 3 Epoch 0: 1%|▊ | 5/625 [00:05<11:35, 1.12s/it, loss=0, v_num=0] XrtCompile: 3, batch_idx = 4 XrtCompile: 3, batch_idx = 5 XrtCompile: 3, batch_idx = 6 XrtCompile: 3, batch_idx = 7 XrtCompile: 3, batch_idx = 8 Epoch 0: 2%|█▌ | 10/625 [00:12<12:22, 1.21s/it, loss=0, v_num=0] XrtCompile: 4, batch_idx = 9 XrtCompile: 4, batch_idx = 10 XrtCompile: 4, batch_idx = 11 XrtCompile: 4, batch_idx = 12 XrtCompile: 4, batch_idx = 13
- CUDA: - GPU: None - available: False - version: 11.7
- Lightning: - lightning-utilities: 0.3.0 - pytorch-lightning: 1.8.2 - torch: 1.13.0 - torch-xla: 1.13 - torchmetrics: 0.10.3 - torchvision: 0.14.0
- System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.8.10 - version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022
More info
Using a TPU-VM (TPU v3) on GCP with the tpu-vm-pt-1.13 image, with the following packages installed via pip:
pytorch-lightning torch pytest datasets scipy regex requests numpy fsspec[gcs] hydra-core transformers httplib2>=0.15.0 tensorboardX protobuf==3.19.5
