TQDMProgressBar refresh forces TPU to recompile compute graph
Bug description
Every TQDMProgressBar refresh forces the TPU to recompile the compute graph, which slows down execution.
How to reproduce the bug
import re
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from pytorch_lightning.callbacks import Callback, TQDMProgressBar
from torch.utils.data import DataLoader, Dataset


class dummyDataset(Dataset):
    def __getitem__(self, index):
        return torch.rand(512), torch.rand(1)

    def __len__(self):
        return 10_000


class dummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(512, 512),
            nn.Linear(512, 1),
        )

    def training_step(self, batch):
        x, y = batch
        logits = self.net(x)
        time.sleep(1)
        return {'loss': F.cross_entropy(logits, y)}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())


class TPUMetricCallback(Callback):
    # Print the cumulative XrtCompile count after every training batch.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if xm.is_master_ordinal():
            report = met.metrics_report()
            # Raw string so that \s and \d are passed to the regex engine unchanged.
            xrt_compile_count = re.search(r'Metric: XrtCompile\s+TotalSamples: (\d+)', report).group(1)
            print(f'XrtCompile: {xrt_compile_count}, batch_idx = {batch_idx}')


def main():
    refresh_rate = 1
    ds = dummyDataset()
    dl = DataLoader(ds, batch_size=16)
    model = dummyModel()
    metrics_callback = TPUMetricCallback()
    tqdm_callback = TQDMProgressBar(refresh_rate=refresh_rate)
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator='tpu',
        devices=1,
        callbacks=[tqdm_callback, metrics_callback],
    )
    trainer.fit(model=model, train_dataloaders=dl)


if __name__ == '__main__':
    main()
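The callback above calls .group(1) directly on the result of re.search, which raises AttributeError if the report contains no XrtCompile entry. A minimal sketch of a more defensive parser follows; the helper name xrt_compile_count and the sample report text are hypothetical and only mimic the layout that the regex in the callback expects.

```python
import re
from typing import Optional

# Same pattern as in the callback; the raw string keeps \s and \d intact.
_XRT_COMPILE = re.compile(r"Metric: XrtCompile\s+TotalSamples: (\d+)")


def xrt_compile_count(report: str) -> Optional[int]:
    """Return the XrtCompile sample count from a metrics report, or None if absent."""
    match = _XRT_COMPILE.search(report)
    return int(match.group(1)) if match else None


# Hypothetical excerpt; a real report would come from met.metrics_report().
sample_report = "Metric: XrtCompile\n  TotalSamples: 3\n"

print(xrt_compile_count(sample_report))
print(xrt_compile_count("Metric: XrtExecute\n  TotalSamples: 12\n"))
```

Returning None instead of raising lets the callback skip the print on steps where the metric has not been recorded yet.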
Error messages and logs
refresh_rate = 1
XrtCompile: 3, batch_idx = 0
Epoch 0: 0%|▎ | 2/625 [00:02<12:42, 1.22s/it, loss=0, v_num=0]
XrtCompile: 4, batch_idx = 1
Epoch 0: 0%|▍ | 3/625 [00:03<12:41, 1.22s/it, loss=0, v_num=0]
XrtCompile: 5, batch_idx = 2
Epoch 0: 1%|▋ | 4/625 [00:04<12:55, 1.25s/it, loss=0, v_num=0]
XrtCompile: 6, batch_idx = 3
Epoch 0: 1%|▊ | 5/625 [00:06<13:22, 1.29s/it, loss=0, v_num=0]
XrtCompile: 7, batch_idx = 4
Epoch 0: 1%|▉ | 6/625 [00:08<13:50, 1.34s/it, loss=0, v_num=0]
XrtCompile: 8, batch_idx = 5
Epoch 0: 1%|█ | 7/625 [00:09<14:26, 1.40s/it, loss=0, v_num=0]
XrtCompile: 9, batch_idx = 6
Epoch 0: 1%|█▎ | 8/625 [00:11<15:02, 1.46s/it, loss=0, v_num=0]
XrtCompile: 10, batch_idx = 7
Epoch 0: 1%|█▍
refresh_rate = 5
Epoch 0: 0%| | 0/625 [00:00<?, ?it/s]
XrtCompile: 2, batch_idx = 0
XrtCompile: 2, batch_idx = 1
XrtCompile: 2, batch_idx = 2
XrtCompile: 2, batch_idx = 3
Epoch 0: 1%|▊ | 5/625 [00:05<11:35, 1.12s/it, loss=0, v_num=0]
XrtCompile: 3, batch_idx = 4
XrtCompile: 3, batch_idx = 5
XrtCompile: 3, batch_idx = 6
XrtCompile: 3, batch_idx = 7
XrtCompile: 3, batch_idx = 8
Epoch 0: 2%|█▌ | 10/625 [00:12<12:22, 1.21s/it, loss=0, v_num=0]
XrtCompile: 4, batch_idx = 9
XrtCompile: 4, batch_idx = 10
XrtCompile: 4, batch_idx = 11
XrtCompile: 4, batch_idx = 12
XrtCompile: 4, batch_idx = 13
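In the refresh_rate = 5 log, the compile count only rises on the steps where the bar redraws (5/625 and 10/625). As a rough check, the sketch below extracts the printed pairs from that log and lists the batch indices at which the count increased. The helper name recompile_steps is made up for illustration, and the progress-bar lines are left out of the pasted text.

```python
import re

# Print lines copied from the refresh_rate = 5 log (progress-bar lines omitted).
log = """
XrtCompile: 2, batch_idx = 0
XrtCompile: 2, batch_idx = 1
XrtCompile: 2, batch_idx = 2
XrtCompile: 2, batch_idx = 3
XrtCompile: 3, batch_idx = 4
XrtCompile: 3, batch_idx = 5
XrtCompile: 3, batch_idx = 6
XrtCompile: 3, batch_idx = 7
XrtCompile: 3, batch_idx = 8
XrtCompile: 4, batch_idx = 9
XrtCompile: 4, batch_idx = 10
XrtCompile: 4, batch_idx = 11
XrtCompile: 4, batch_idx = 12
XrtCompile: 4, batch_idx = 13
"""

# (compile_count, batch_idx) for every printed line, in order.
pairs = [
    (int(count), int(idx))
    for count, idx in re.findall(r"XrtCompile: (\d+), batch_idx = (\d+)", log)
]


def recompile_steps(pairs):
    """Return batch indices where the compile count is higher than on the previous step."""
    return [idx for (prev, _), (cur, idx) in zip(pairs, pairs[1:]) if cur > prev]


print(recompile_steps(pairs))  # [4, 9]
```

Batch indices 4 and 9 are the 5th and 10th batches, which matches one recompilation per progress-bar refresh rather than one per training step.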
Environment
- CUDA:
  - GPU: None
  - available: False
  - version: 11.7
- Lightning:
  - lightning-utilities: 0.3.0
  - pytorch-lightning: 1.8.2
  - torch: 1.13.0
  - torch-xla: 1.13
  - torchmetrics: 0.10.3
  - torchvision: 0.14.0
- Packages: - absl-py: 1.3.0 - aiohttp: 3.8.3 - aiosignal: 1.3.1 - antlr4-python3-runtime: 4.9.3 - async-timeout: 4.0.2 - attrs: 19.3.0 - automat: 0.8.0 - blinker: 1.4 - cachetools: 5.2.0 - certifi: 2019.11.28 - chardet: 3.0.4 - charset-normalizer: 2.0.12 - click: 7.0 - cloud-init: 22.1 - cloud-tpu-client: 0.10 - colorama: 0.4.3 - command-not-found: 0.3 - configobj: 5.0.6 - constantly: 15.1.0 - cryptography: 2.8 - cython: 0.29.14 - datasets: 2.7.0 - dbus-python: 1.2.16 - decorator: 5.1.1 - dill: 0.3.6 - distlib: 0.3.4 - distro: 1.4.0 - distro-info: 0.23ubuntu1 - entrypoints: 0.3 - exceptiongroup: 1.0.4 - filelock: 3.7.1 - fire: 0.4.0 - frozenlist: 1.3.3 - fsspec: 2022.11.0 - gcsfs: 2022.11.0 - google-api-core: 1.33.2 - google-api-python-client: 1.8.0 - google-auth: 2.13.0 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 0.4.6 - google-cloud-core: 2.3.2 - google-cloud-storage: 2.6.0 - google-crc32c: 1.5.0 - google-resumable-media: 2.4.0 - googleapis-common-protos: 1.56.4 - grpcio: 1.50.0 - httplib2: 0.21.0 - huggingface-hub: 0.11.0 - hydra-core: 1.2.0 - hyperlink: 19.0.0 - idna: 2.8 - importlib-metadata: 5.0.0 - importlib-resources: 5.10.0 - incremental: 16.10.1 - iniconfig: 1.1.1 - intel-openmp: 2022.1.0 - jinja2: 2.10.1 - jsonpatch: 1.22 - jsonpointer: 2.0 - jsonschema: 3.2.0 - keyring: 18.0.1 - language-selector: 0.1 - launchpadlib: 1.10.13 - lazr.restfulclient: 0.14.2 - lazr.uri: 1.0.3 - libtpu-nightly: 0.1.dev20220930 - lightning-utilities: 0.3.0 - markdown: 3.4.1 - markupsafe: 2.1.1 - mkl: 2022.1.0 - mkl-include: 2022.1.0 - more-itertools: 4.2.0 - multidict: 6.0.2 - multiprocess: 0.70.14 - netifaces: 0.10.4 - numpy: 1.23.4 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - oauth2client: 4.1.3 - oauthlib: 3.1.0 - omegaconf: 2.2.3 - packaging: 20.3 - pandas: 1.5.1 - pexpect: 4.6.0 - pillow: 9.3.0 - pip: 20.0.2 - platformdirs: 2.5.2 - pluggy: 1.0.0 - protobuf: 3.19.5 - 
pyarrow: 10.0.0 - pyasn1: 0.4.2 - pyasn1-modules: 0.2.1 - pygobject: 3.36.0 - pyhamcrest: 1.9.0 - pyjwt: 1.7.1 - pymacaroons: 0.13.0 - pynacl: 1.3.0 - pyopenssl: 19.0.0 - pyparsing: 2.4.6 - pyrsistent: 0.15.5 - pyserial: 3.4 - pytest: 7.2.0 - python-apt: 2.0.0+ubuntu0.20.4.7 - python-dateutil: 2.8.2 - python-debian: 0.1.36ubuntu1 - pytorch-lightning: 1.8.2 - pytz: 2022.6 - pyyaml: 5.4.1 - regex: 2022.10.31 - requests: 2.27.1 - requests-oauthlib: 1.3.1 - requests-unixsocket: 0.2.0 - responses: 0.18.0 - rsa: 4.9 - scipy: 1.9.3 - secretstorage: 2.3.1 - service-identity: 18.1.0 - setuptools: 62.3.2 - simplejson: 3.16.0 - six: 1.14.0 - sos: 4.3 - ssh-import-id: 5.10 - systemd-python: 234 - tbb: 2021.6.0 - tensorboard: 2.11.0 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.5.1 - termcolor: 2.1.0 - tokenizers: 0.13.2 - tomli: 2.0.1 - torch: 1.13.0 - torch-xla: 1.13 - torchmetrics: 0.10.3 - torchvision: 0.14.0 - tqdm: 4.64.1 - transformers: 4.24.0 - twisted: 18.9.0 - typing-extensions: 4.4.0 - ubuntu-advantage-tools: 27.8 - ufw: 0.36 - unattended-upgrades: 0.1 - uritemplate: 3.0.1 - urllib3: 1.25.8 - virtualenv: 20.14.1 - wadllib: 1.3.3 - werkzeug: 2.2.2 - wheel: 0.34.2 - xxhash: 3.1.0 - yarl: 1.8.1 - zipp: 3.10.0 - zope.interface: 4.7.1
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.8.10
  - version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022
More info
I am using a TPU VM (TPU v3) on GCP with the tpu-vm-pt-1.13 image. The following packages were installed with pip:
pytorch-lightning torch pytest datasets scipy regex requests numpy fsspec[gcs] hydra-core transformers httplib2>=0.15.0 tensorboardX protobuf==3.19.5
cc @awaelchli
Top GitHub Comments
It could be that this is caused by the metrics collection (the loss etc. we show in the bar), which involves tensor operations.
You might be interested in #16020, which now handles xm.mark_step() correctly. This might fix your issue.