Different versions of PyTorch Lightning provide different results while training on ImageNet
🐛 Bug
While training a ResNet18-like CNN on ImageNet, different versions of PyTorch Lightning reach very different accuracies. Minimal reproduction code is in this GitHub repository, and the experiments are logged in this wandb project.
To Reproduce
Create 3 different conda environments, where the only difference between them is the PyTorch Lightning version. Launch the trainings, each one in its own environment. Details are in the README of the GitHub repository I created for this issue.
Expected behavior
- Training using version 1.5.4 reaches 67%/70% validation/train accuracy. For more details see the wandb run.
- Training using version 1.6.5 reaches 44%/41% validation/train accuracy. For more details see the wandb run.
- Training using version 1.7.3 reaches 49%/61% validation/train accuracy. For more details see the wandb run.
Environment
There are 3 different environments; I'm adding the output of collect_env_details.py for one of them. The other two are almost identical - the only difference is the PyTorch Lightning version.
* CUDA:
- GPU:
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- NVIDIA GeForce RTX 2080 Ti
- available: True
- version: 11.3
* Lightning:
- pytorch-lightning: 1.5.4
- torch: 1.12.1
- torchmetrics: 0.9.3
- torchvision: 0.13.1
* Packages:
- absl-py: 1.2.0
- aiohttp: 3.8.1
- aiosignal: 1.2.0
- async-timeout: 4.0.2
- attrs: 22.1.0
- blinker: 1.4
- brotlipy: 0.7.0
- cachetools: 5.2.0
- certifi: 2022.6.15
- cffi: 1.15.1
- charset-normalizer: 2.1.1
- click: 8.1.3
- colorama: 0.4.5
- cryptography: 37.0.4
- cycler: 0.11.0
- docker-pycreds: 0.4.0
- flatten-dict: 0.4.2
- fonttools: 4.37.1
- frozenlist: 1.3.1
- fsspec: 2022.7.1
- future: 0.18.2
- gitdb: 4.0.9
- gitpython: 3.1.27
- google-auth: 2.11.0
- google-auth-oauthlib: 0.4.6
- grpcio: 1.48.0
- idna: 3.3
- importlib-metadata: 4.11.4
- joblib: 1.1.0
- kiwisolver: 1.4.4
- loguru: 0.6.0
- markdown: 3.4.1
- markupsafe: 2.1.1
- matplotlib: 3.5.3
- multidict: 6.0.2
- munkres: 1.1.4
- numpy: 1.23.2
- oauthlib: 3.2.0
- packaging: 21.3
- pandas: 1.4.3
- pathlib2: 2.3.7.post1
- pathtools: 0.1.2
- pillow: 9.2.0
- pip: 22.2.2
- plotly: 5.10.0
- ply: 3.11
- promise: 2.3
- protobuf: 4.21.5
- psutil: 5.9.1
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.7
- pycparser: 2.21
- pydantic: 1.9.2
- pydeprecate: 0.3.1
- pyjwt: 2.4.0
- pyopenssl: 22.0.0
- pyparsing: 3.0.9
- pyqt5: 5.15.7
- pyqt5-sip: 12.11.0
- pysocks: 1.7.1
- python-dateutil: 2.8.2
- pytorch-lightning: 1.5.4
- pytz: 2022.2.1
- pyu2f: 0.1.5
- pyyaml: 6.0
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- rsa: 4.9
- scikit-learn: 1.1.2
- scipy: 1.9.0
- sentry-sdk: 1.9.5
- setproctitle: 1.3.2
- setuptools: 65.3.0
- shortuuid: 1.0.9
- sip: 6.6.2
- six: 1.16.0
- smmap: 3.0.5
- tenacity: 8.0.1
- tensorboard: 2.9.0
- tensorboard-data-server: 0.6.0
- tensorboard-plugin-wit: 1.8.1
- threadpoolctl: 3.1.0
- tikzplotlib: 0.10.1
- toml: 0.10.2
- torch: 1.12.1
- torchmetrics: 0.9.3
- torchvision: 0.13.1
- tornado: 6.2
- tqdm: 4.64.0
- typing-extensions: 4.3.0
- unicodedata2: 14.0.0
- urllib3: 1.26.11
- wandb: 0.13.1
- webcolors: 1.12
- werkzeug: 2.2.2
- wheel: 0.37.1
- yarl: 1.7.2
- zipp: 3.8.1
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.13
- version: #30~18.04.1-Ubuntu SMP Fri Jan 17 06:14:09 UTC 2020
cc @awaelchli
Since we are looking at graphs that plot accuracy, it is worth noting that the accuracy computation here is not correct: https://github.com/AlonNT/pl-reproduce/blob/edcc17603e256ce37e40dc1fb5ef05257becb219/utils.py#L648-L656. For the sake of sanity and evaluation, I recommend using torchmetrics there. I don't expect the numbers to be drastically different, but it is important not to overlook this in the context of distributed training.
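As an illustration (not the exact code from the repo; the class name, model, and metric key are mine), a minimal sketch of what I mean, letting torchmetrics handle the cross-process reduction:

```python
import torchmetrics
import torchvision
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    """Illustrative module; the real one lives in the linked repository."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.model = torchvision.models.resnet18(num_classes=num_classes)
        # torchmetrics keeps its state on the correct device and reduces it
        # across processes itself, so the epoch-level accuracy is correct
        # under DDP without any manual gathering.
        self.val_acc = torchmetrics.Accuracy(num_classes=num_classes)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        # Update the metric state; passing the metric object to self.log
        # lets Lightning call compute()/reset() at the end of the epoch.
        self.val_acc(logits, y)
        self.log("val/acc", self.val_acc, on_epoch=True)
```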
Another reason I mention this is that I looked at the wandb plots you linked and the training loss curves seem to be identical (note: no seed was set, which could explain the small variation in the curves). The validation loss is different, which leads me to believe that there might be a difference between the versions in either 1) the self.log behavior or 2) the distributed sampling during the validation phase.
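For what it's worth, pinning the seed is a one-liner in Lightning (a sketch; the seed value is arbitrary):

```python
import pytorch_lightning as pl

# Seeds python, numpy and torch; workers=True additionally seeds the
# dataloader worker processes, so augmentation randomness is fixed too.
pl.seed_everything(42, workers=True)
```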
To rule them out, I would do the following: in addition to logging with self.log, log directly using self.logger.experiment.log(...) (the wandb API), which logs on rank 0 only. This will tell us whether the difference comes from the internal logging code or from something else; a sketch of what I mean is below.

EDIT: Btw, kudos for running this in different environments where only the PL version changes, while keeping everything else the same. This restores a lot of sanity and eliminates tons of speculation about unrelated causes!
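A sketch of what I mean (only validation_step is shown; it assumes the model and loss from the linked repo and that WandbLogger is the active logger; the metric keys are illustrative):

```python
import torch.nn.functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    """Only validation_step is sketched; self.model is assumed to be the
    network defined in the linked repository."""

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x), y)

        # Path 1: Lightning's own logging, with epoch-level reduction and
        # (via sync_dist=True) averaging across ranks.
        self.log("val/loss", loss, on_epoch=True, sync_dist=True)

        # Path 2: log straight to the underlying wandb run. With WandbLogger,
        # `experiment` is a no-op dummy on non-zero ranks, so this line
        # records rank 0's values only and bypasses Lightning's machinery.
        self.logger.experiment.log({"val/loss_rank0": loss.item()})
        return loss
```

If the two curves disagree only between versions on path 1, the internal logging (or its cross-rank reduction) is the likely culprit; if both paths disagree, the data each rank sees during validation probably differs.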
cc @akihironitta, this is an interesting example where benchmarking would help