
TQDMProgressBar refresh forces TPU to recompile compute graph

See original GitHub issue

Bug description

Each TQDMProgressBar refresh forces the TPU to recompile the compute graph, which slows execution down considerably.

How to reproduce the bug

import re
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import torch_xla.debug.metrics as met
from pytorch_lightning.callbacks import TQDMProgressBar
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.callbacks import Callback

import torch_xla.core.xla_model as xm


class dummyDataset(Dataset):
    def __getitem__(self, index):
        return torch.rand(512), torch.rand(1)

    def __len__(self):
        return 10_000

class dummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(512, 512),
            nn.Linear(512, 1),
        )

    def training_step(self, batch):
        x, y = batch
        logits = self.net(x)
        time.sleep(1)  # slow each step down so timing differences are visible
        return {'loss': F.cross_entropy(logits, y)}

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())

class TPUMetricCallback(Callback):
    """Print the cumulative XLA compile count after every batch."""
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if xm.is_master_ordinal():
            report = met.metrics_report()
            xrt_compile_count = re.search(r'Metric: XrtCompile\s+TotalSamples: (\d+)', report).group(1)
            print(f'XrtCompile: {xrt_compile_count}, batch_idx = {batch_idx}')

def main():
    refresh_rate = 1

    ds = dummyDataset()
    dl = DataLoader(ds, batch_size=16)
    model = dummyModel()

    metrics_callback = TPUMetricCallback()
    tqdm_callback = TQDMProgressBar(refresh_rate=refresh_rate)

    trainer = pl.Trainer(max_epochs=10,
                         accelerator='tpu',
                         devices=1,
                         callbacks=[tqdm_callback, metrics_callback])

    trainer.fit(model=model,
                train_dataloaders=dl
                )


if __name__ == '__main__':
    main()

Error messages and logs

With refresh_rate = 1:

XrtCompile: 3, batch_idx = 0
Epoch 0: 0%|▎ | 2/625 [00:02<12:42, 1.22s/it, loss=0, v_num=0]
XrtCompile: 4, batch_idx = 1
Epoch 0: 0%|▍ | 3/625 [00:03<12:41, 1.22s/it, loss=0, v_num=0]
XrtCompile: 5, batch_idx = 2
Epoch 0: 1%|▋ | 4/625 [00:04<12:55, 1.25s/it, loss=0, v_num=0]
XrtCompile: 6, batch_idx = 3
Epoch 0: 1%|▊ | 5/625 [00:06<13:22, 1.29s/it, loss=0, v_num=0]
XrtCompile: 7, batch_idx = 4
Epoch 0: 1%|▉ | 6/625 [00:08<13:50, 1.34s/it, loss=0, v_num=0]
XrtCompile: 8, batch_idx = 5
Epoch 0: 1%|█ | 7/625 [00:09<14:26, 1.40s/it, loss=0, v_num=0]
XrtCompile: 9, batch_idx = 6
Epoch 0: 1%|█▎ | 8/625 [00:11<15:02, 1.46s/it, loss=0, v_num=0]
XrtCompile: 10, batch_idx = 7
Epoch 0: 1%|█▍

With refresh_rate = 5:

Epoch 0: 0%| | 0/625 [00:00<?, ?it/s]
XrtCompile: 2, batch_idx = 0
XrtCompile: 2, batch_idx = 1
XrtCompile: 2, batch_idx = 2
XrtCompile: 2, batch_idx = 3
Epoch 0: 1%|▊ | 5/625 [00:05<11:35, 1.12s/it, loss=0, v_num=0]
XrtCompile: 3, batch_idx = 4
XrtCompile: 3, batch_idx = 5
XrtCompile: 3, batch_idx = 6
XrtCompile: 3, batch_idx = 7
XrtCompile: 3, batch_idx = 8
Epoch 0: 2%|█▌ | 10/625 [00:12<12:22, 1.21s/it, loss=0, v_num=0]
XrtCompile: 4, batch_idx = 9
XrtCompile: 4, batch_idx = 10
XrtCompile: 4, batch_idx = 11
XrtCompile: 4, batch_idx = 12
XrtCompile: 4, batch_idx = 13
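
The counters above grow by one per batch at refresh_rate = 1 and by one per refresh at refresh_rate = 5, which points at the refresh itself. A quick control experiment (not confirmed in the issue; both options below are standard Trainer/TQDMProgressBar arguments, and whether they fully avoid the recompilation on a given torch-xla version should be verified against the metrics report) is to refresh far less often, or disable the bar entirely:

# Option 1: refresh the bar every 50 batches instead of every batch.
tqdm_callback = TQDMProgressBar(refresh_rate=50)

# Option 2: drop the bar entirely, reusing TPUMetricCallback from the
# reproducer to watch the compile count.
trainer = pl.Trainer(max_epochs=10,
                     accelerator='tpu',
                     devices=1,
                     enable_progress_bar=False,
                     callbacks=[TPUMetricCallback()])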

Environment

  • CUDA:
      - GPU: None
      - available: False
      - version: 11.7
  • Lightning:
      - lightning-utilities: 0.3.0
      - pytorch-lightning: 1.8.2
      - torch: 1.13.0
      - torch-xla: 1.13
      - torchmetrics: 0.10.3
      - torchvision: 0.14.0
  • Packages: - absl-py: 1.3.0 - aiohttp: 3.8.3 - aiosignal: 1.3.1 - antlr4-python3-runtime: 4.9.3 - async-timeout: 4.0.2 - attrs: 19.3.0 - automat: 0.8.0 - blinker: 1.4 - cachetools: 5.2.0 - certifi: 2019.11.28 - chardet: 3.0.4 - charset-normalizer: 2.0.12 - click: 7.0 - cloud-init: 22.1 - cloud-tpu-client: 0.10 - colorama: 0.4.3 - command-not-found: 0.3 - configobj: 5.0.6 - constantly: 15.1.0 - cryptography: 2.8 - cython: 0.29.14 - datasets: 2.7.0 - dbus-python: 1.2.16 - decorator: 5.1.1 - dill: 0.3.6 - distlib: 0.3.4 - distro: 1.4.0 - distro-info: 0.23ubuntu1 - entrypoints: 0.3 - exceptiongroup: 1.0.4 - filelock: 3.7.1 - fire: 0.4.0 - frozenlist: 1.3.3 - fsspec: 2022.11.0 - gcsfs: 2022.11.0 - google-api-core: 1.33.2 - google-api-python-client: 1.8.0 - google-auth: 2.13.0 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 0.4.6 - google-cloud-core: 2.3.2 - google-cloud-storage: 2.6.0 - google-crc32c: 1.5.0 - google-resumable-media: 2.4.0 - googleapis-common-protos: 1.56.4 - grpcio: 1.50.0 - httplib2: 0.21.0 - huggingface-hub: 0.11.0 - hydra-core: 1.2.0 - hyperlink: 19.0.0 - idna: 2.8 - importlib-metadata: 5.0.0 - importlib-resources: 5.10.0 - incremental: 16.10.1 - iniconfig: 1.1.1 - intel-openmp: 2022.1.0 - jinja2: 2.10.1 - jsonpatch: 1.22 - jsonpointer: 2.0 - jsonschema: 3.2.0 - keyring: 18.0.1 - language-selector: 0.1 - launchpadlib: 1.10.13 - lazr.restfulclient: 0.14.2 - lazr.uri: 1.0.3 - libtpu-nightly: 0.1.dev20220930 - lightning-utilities: 0.3.0 - markdown: 3.4.1 - markupsafe: 2.1.1 - mkl: 2022.1.0 - mkl-include: 2022.1.0 - more-itertools: 4.2.0 - multidict: 6.0.2 - multiprocess: 0.70.14 - netifaces: 0.10.4 - numpy: 1.23.4 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - oauth2client: 4.1.3 - oauthlib: 3.1.0 - omegaconf: 2.2.3 - packaging: 20.3 - pandas: 1.5.1 - pexpect: 4.6.0 - pillow: 9.3.0 - pip: 20.0.2 - platformdirs: 2.5.2 - pluggy: 1.0.0 - protobuf: 3.19.5 - pyarrow: 10.0.0 - pyasn1: 0.4.2 - pyasn1-modules: 0.2.1 - pygobject: 3.36.0 - pyhamcrest: 1.9.0 - pyjwt: 1.7.1 - pymacaroons: 0.13.0 - pynacl: 1.3.0 - pyopenssl: 19.0.0 - pyparsing: 2.4.6 - pyrsistent: 0.15.5 - pyserial: 3.4 - pytest: 7.2.0 - python-apt: 2.0.0+ubuntu0.20.4.7 - python-dateutil: 2.8.2 - python-debian: 0.1.36ubuntu1 - pytorch-lightning: 1.8.2 - pytz: 2022.6 - pyyaml: 5.4.1 - regex: 2022.10.31 - requests: 2.27.1 - requests-oauthlib: 1.3.1 - requests-unixsocket: 0.2.0 - responses: 0.18.0 - rsa: 4.9 - scipy: 1.9.3 - secretstorage: 2.3.1 - service-identity: 18.1.0 - setuptools: 62.3.2 - simplejson: 3.16.0 - six: 1.14.0 - sos: 4.3 - ssh-import-id: 5.10 - systemd-python: 234 - tbb: 2021.6.0 - tensorboard: 2.11.0 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.5.1 - termcolor: 2.1.0 - tokenizers: 0.13.2 - tomli: 2.0.1 - torch: 1.13.0 - torch-xla: 1.13 - torchmetrics: 0.10.3 - torchvision: 0.14.0 - tqdm: 4.64.1 - transformers: 4.24.0 - twisted: 18.9.0 - typing-extensions: 4.4.0 - ubuntu-advantage-tools: 27.8 - ufw: 0.36 - unattended-upgrades: 0.1 - uritemplate: 3.0.1 - urllib3: 1.25.8 - virtualenv: 20.14.1 - wadllib: 1.3.3 - werkzeug: 2.2.2 - wheel: 0.34.2 - xxhash: 3.1.0 - yarl: 1.8.1 - zipp: 3.10.0 - zope.interface: 4.7.1
  • System:
      - OS: Linux
      - architecture: 64bit, ELF
      - processor: x86_64
      - python: 3.8.10
      - version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022

More info

Using a TPU-VM (TPUv3) on GCP with the tpu-vm-pt-1.13 image; the following packages were installed with pip:

pytorch-lightning torch pytest datasets scipy regex requests numpy fsspec[gcs] hydra-core transformers httplib2>=0.15.0 tensorboardX protobuf==3.19.5

cc @awaelchli

Issue Analytics

  • State: open
  • Created 10 months ago
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

1 reaction
awaelchli commented, Nov 20, 2022

It could be that this is caused by the metrics collection (the loss etc. we show in the bar), which involves tensor operations.
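
That matches how lazy tensors behave under XLA: operations are only recorded until something reads a value back to the host, as tqdm does when it formats loss=... for the bar, and each such read cuts and executes the pending graph; if the cut lands at a different point than the previous one, a new graph gets compiled. A minimal standalone sketch of that behaviour (it assumes a TPU/XLA runtime is available and reuses the reproducer's regex trick; it is not Lightning code):

import re
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def xrt_compile_count():
    # Same regex as TPUMetricCallback in the reproducer above.
    m = re.search(r'Metric: XrtCompile\s+TotalSamples: (\d+)', met.metrics_report())
    return int(m.group(1)) if m else 0

device = xm.xla_device()
x = torch.rand(512, device=device)

y = (x * 2).sum()               # only recorded; nothing executes yet
before = xrt_compile_count()
value = y.item()                # host read: cuts and runs the pending graph
after = xrt_compile_count()
print(before, after)            # the host read is what triggered the compile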

0 reactions
awaelchli commented, Dec 14, 2022

You might be interested to check #16020, which now handles xm.mark_step() correctly. This might fix your issue.
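
For background, the usual torch-xla discipline that such a fix enforces is to call xm.mark_step() at the same point of every iteration, so each step traces an identical graph that is compiled once and then reused; host-side reads like a progress-bar refresh then see already-materialized values instead of splicing extra operations into the training graph. A bare-loop sketch of that pattern (an illustration of the idea, not the actual change in #16020):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(512, 1).to(device)
opt = torch.optim.AdamW(model.parameters())

for step in range(20):
    x = torch.rand(16, 512, device=device)
    loss = model(x).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    xm.mark_step()              # cut the graph at the same point every step
    if step % 5 == 0:
        # loss was materialized by mark_step(), so this host read does not
        # introduce a new graph cut.
        print(f'step {step}: loss={loss.item():.4f}')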
