
Different versions of PyTorch Lightning provide different results while training on ImageNet


🐛 Bug

While training a ResNet18-like CNN on ImageNet, different versions of PyTorch Lightning reach very different accuracies. Minimal code is in this GitHub repository, and the experiments are logged in this wandb project.

To Reproduce

Create 3 different conda environments, where the only difference between them is the PyTorch Lightning version. Launch the trainings, each one in its own environment. Details are in the README of the GitHub repository I created for this issue.
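The three-environment setup can be sketched roughly as follows. Note this is illustrative only: the environment names and the exact install commands are made up, and the issue repository's README is the authoritative reference.

```shell
# Illustrative sketch of the three-environment setup (names are hypothetical;
# see the issue repository's README for the exact commands). Each environment
# pins the same torch/torchvision/torchmetrics and differs only in the
# pytorch-lightning version.
for version in 1.5.4 1.6.5 1.7.3; do
    conda create -y -n "pl-${version}" python=3.9
    conda run -n "pl-${version}" pip install \
        "pytorch-lightning==${version}" \
        torch==1.12.1 torchvision==0.13.1 torchmetrics==0.9.3
done
```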

Expected behavior

  • Training using version 1.5.4 reaches 67%/70% validation/train accuracy.
    For more details see the wandb run.
  • Training using version 1.6.5 reaches 44%/41% validation/train accuracy.
    For more details see the wandb run.
  • Training using version 1.7.3 reaches 49%/61% validation/train accuracy.
    For more details see the wandb run.

(Figures: train_accuracy and validation_accuracy plots for the three runs.)

Environment

There are 3 different environments; I'm adding the output of collect_env_details.py from one of them. The other two are almost the same: the only difference is the PyTorch Lightning version.

* CUDA:
        - GPU:
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
                - NVIDIA GeForce RTX 2080 Ti
        - available:         True
        - version:           11.3
* Lightning:
        - pytorch-lightning: 1.5.4
        - torch:             1.12.1
        - torchmetrics:      0.9.3
        - torchvision:       0.13.1
* Packages:
        - absl-py:           1.2.0
        - aiohttp:           3.8.1
        - aiosignal:         1.2.0
        - async-timeout:     4.0.2
        - attrs:             22.1.0
        - blinker:           1.4
        - brotlipy:          0.7.0
        - cachetools:        5.2.0
        - certifi:           2022.6.15
        - cffi:              1.15.1
        - charset-normalizer: 2.1.1
        - click:             8.1.3
        - colorama:          0.4.5
        - cryptography:      37.0.4
        - cycler:            0.11.0
        - docker-pycreds:    0.4.0
        - flatten-dict:      0.4.2
        - fonttools:         4.37.1
        - frozenlist:        1.3.1
        - fsspec:            2022.7.1
        - future:            0.18.2
        - gitdb:             4.0.9
        - gitpython:         3.1.27
        - google-auth:       2.11.0
        - google-auth-oauthlib: 0.4.6
        - grpcio:            1.48.0
        - idna:              3.3
        - importlib-metadata: 4.11.4
        - joblib:            1.1.0
        - kiwisolver:        1.4.4
        - loguru:            0.6.0
        - markdown:          3.4.1
        - markupsafe:        2.1.1
        - matplotlib:        3.5.3
        - multidict:         6.0.2
        - munkres:           1.1.4
        - numpy:             1.23.2
        - oauthlib:          3.2.0
        - packaging:         21.3
        - pandas:            1.4.3
        - pathlib2:          2.3.7.post1
        - pathtools:         0.1.2
        - pillow:            9.2.0
        - pip:               22.2.2
        - plotly:            5.10.0
        - ply:               3.11
        - promise:           2.3
        - protobuf:          4.21.5
        - psutil:            5.9.1
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.7
        - pycparser:         2.21
        - pydantic:          1.9.2
        - pydeprecate:       0.3.1
        - pyjwt:             2.4.0
        - pyopenssl:         22.0.0
        - pyparsing:         3.0.9
        - pyqt5:             5.15.7
        - pyqt5-sip:         12.11.0
        - pysocks:           1.7.1
        - python-dateutil:   2.8.2
        - pytorch-lightning: 1.5.4
        - pytz:              2022.2.1
        - pyu2f:             0.1.5
        - pyyaml:            6.0
        - requests:          2.28.1
        - requests-oauthlib: 1.3.1
        - rsa:               4.9
        - scikit-learn:      1.1.2
        - scipy:             1.9.0
        - sentry-sdk:        1.9.5
        - setproctitle:      1.3.2
        - setuptools:        65.3.0
        - shortuuid:         1.0.9
        - sip:               6.6.2
        - six:               1.16.0
        - smmap:             3.0.5
        - tenacity:          8.0.1
        - tensorboard:       2.9.0
        - tensorboard-data-server: 0.6.0
        - tensorboard-plugin-wit: 1.8.1
        - threadpoolctl:     3.1.0
        - tikzplotlib:       0.10.1
        - toml:              0.10.2
        - torch:             1.12.1
        - torchmetrics:      0.9.3
        - torchvision:       0.13.1
        - tornado:           6.2
        - tqdm:              4.64.0
        - typing-extensions: 4.3.0
        - unicodedata2:      14.0.0
        - urllib3:           1.26.11
        - wandb:             0.13.1
        - webcolors:         1.12
        - werkzeug:          2.2.2
        - wheel:             0.37.1
        - yarl:              1.7.2
        - zipp:              3.8.1
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.9.13
        - version:           #30~18.04.1-Ubuntu SMP Fri Jan 17 06:14:09 UTC 2020

cc @awaelchli

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 12 (6 by maintainers)

Top GitHub Comments

awaelchli commented on Aug 30, 2022 (3 reactions)

Since we are looking at graphs that plot accuracy, I think it is worth noting that the computation of accuracy here is not correct: https://github.com/AlonNT/pl-reproduce/blob/edcc17603e256ce37e40dc1fb5ef05257becb219/utils.py#L648-L656 For the sake of sanity and evaluation, I recommend using torchmetrics there. I don’t expect the numbers to be drastically different, but I think this is important not to overlook in the context of distributed training.
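To see why hand-rolled accuracy is a common pitfall that torchmetrics avoids (it accumulates raw counts and syncs them across processes), here is a minimal pure-Python sketch with made-up numbers. The same effect occurs across DDP ranks with uneven data shards, not just across batches of unequal size.

```python
# Why averaging per-batch (or per-rank) accuracies can differ from the true
# dataset-level accuracy when the groups are unequal in size. The numbers
# below are invented for illustration.

batches = [
    (8, 10),  # 8 correct out of 10 samples
    (1, 2),   # a smaller final batch: 1 correct out of 2
]

# Mean of per-batch accuracies: (0.8 + 0.5) / 2 ≈ 0.65
mean_of_batch_acc = sum(c / n for c, n in batches) / len(batches)

# True global accuracy: 9 correct out of 12 samples = 0.75
global_acc = sum(c for c, _ in batches) / sum(n for _, n in batches)

print(mean_of_batch_acc, global_acc)
```

Accumulating the raw correct/total counts and dividing once at the end (as torchmetrics does) gives the global figure; averaging ratios does not.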

Another reason why I mention this is because I looked at the wandb plots you linked and the training loss curves seem to be identical (note, no seed was set which could explain the small variation in the curves). The validation loss is different which leads me to believe that there might be a difference between either 1) self.log behavior 2) distributed sampling between the different versions during validation phase.
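On the missing seed: in a Lightning project one would typically call pytorch_lightning.seed_everything(seed, workers=True). As a hedged sketch of what such a helper does, here is a standard-library version (the torch and numpy parts are guarded since this is framework-agnostic pseudocode for the same idea, not Lightning's actual implementation):

```python
# Minimal sketch of whole-program seeding, in the spirit of
# pytorch_lightning.seed_everything (NOT its actual implementation).
import os
import random


def seed_everything(seed: int) -> None:
    random.seed(seed)
    # Recorded for child processes; hash seeding of the current interpreter
    # is fixed at startup and cannot be changed here.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU RNG (and current CUDA device RNGs)
    except ImportError:
        pass


seed_everything(42)
print(random.random())  # reproducible across runs with the same seed
```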

To rule these out, I would do the following: in addition to logging with self.log, log directly using self.logger.experiment.log(...) (the wandb API). This logs on rank 0 only, and will tell us whether the difference comes from the internal logging code or from something else.
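The diagnostic value of that cross-check can be sketched with made-up numbers: self.log with distributed sync reports a cross-rank reduction (a mean, by default), while direct wandb logging shows the raw rank-0 value. If the two diverge substantially, the ranks are seeing different data or computing the metric differently.

```python
# Hedged sketch of the suggested diagnostic; the per-rank values are invented.
# `synced` stands in for what self.log(..., sync_dist=True) would report
# (a mean reduction over ranks); `rank0_only` stands in for what
# self.logger.experiment.log(...) would record on rank 0.

per_rank_val_acc = [0.50, 0.44, 0.47, 0.45]  # hypothetical metric on 4 ranks

synced = sum(per_rank_val_acc) / len(per_rank_val_acc)
rank0_only = per_rank_val_acc[0]

# A large gap between the two numbers points at distributed sampling or
# reduction behavior rather than the model itself.
print(synced, rank0_only)
```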

EDIT: Btw, kudos for running this in different environments where only the PL version changes, while keeping everything else the same. This restores a lot of sanity and eliminates a lot of speculation about unrelated causes!

carmocca commented on Aug 30, 2022 (1 reaction)

cc @akihironitta this is an interesting example where benchmarking would help


Top Results From Across the Web

  • Step-by-step Walk-through - PyTorch Lightning - Read the Docs
    The Research. The Model. The lightning module holds all the core research ingredients: the model, the optimizers, the train/val/test steps.
  • An introduction to PyTorch Lightning with comparisons to ...
    In this blogpost, we will be going through an introduction to PL and implement all the cool tricks like Gradient Accumulation, 16-bit …
  • Training loss is different when using different pytorch version
    I tested my network based on gpt2 with various pytorch versions but the results are like above. I set the same random seed value…
  • Transfer Learning Using PyTorch Lightning
    Here A and B can be the same deep learning tasks but on a different dataset. Why? The first few hidden layers of…
  • Distributed Deep Learning With PyTorch Lightning (Part 1)
    Distributed training is a method of scaling models and data to multiple devices for parallel execution. It generally yields a speedup that…
