Memory blows up when training large models on all TPU cores
🐛 Bug
I am training on 8 TPU cores but the memory blows up when the epoch ends.
To Reproduce
Try training a BERT-large model on 8 TPU cores.
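A minimal sketch of the kind of run that triggers this, assuming the Lightning ~1.5-era tpu_cores Trainer argument and the Hugging Face bert-large-uncased checkpoint; the random dataset and hyperparameters are placeholders, not the exact script from this report:

```python
# Hypothetical repro sketch (not the original training script).
# Assumes pytorch-lightning ~1.5 (tpu_cores argument) and the transformers library.
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from transformers import BertForSequenceClassification


class RandomTokenDataset(Dataset):
    """Placeholder data: random token ids with binary labels."""
    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        return {
            "input_ids": torch.randint(0, 30522, (128,)),
            "labels": torch.randint(0, 2, (1,)).squeeze(),
        }


class BertLargeModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained("bert-large-uncased")

    def training_step(self, batch, batch_idx):
        out = self.model(input_ids=batch["input_ids"], labels=batch["labels"])
        return out.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)


if __name__ == "__main__":
    # Host memory spikes at the end of epoch 1, before epoch 2 starts.
    trainer = pl.Trainer(tpu_cores=8, max_epochs=2)
    trainer.fit(BertLargeModule(), DataLoader(RandomTokenDataset(), batch_size=8))
```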
Expected behavior
The second epoch should start.
Environment
Kaggle TPU
- PyTorch Lightning Version (e.g., 1.5.0):
- PyTorch Version (e.g., 1.10):
- Python version (e.g., 3.9):
- OS (e.g., Linux): Linux
- CUDA/cuDNN version:
- GPU models and configuration:
- How you installed PyTorch (conda, pip, source):
- If compiling from source, the output of torch.__config__.show():
- Any other relevant information:
Additional context
I will try to use GPUs instead of TPUs.
cc @kaushikb11 @rohitgr7 @awaelchli @ananthsub @ninginthecloud
Actually, I know what causes this issue. Earlier I was using torch/xla directly. The memory blows up when it tries to save the model at the end of the epoch; if you remove the checkpointing code, it works fine. The issue is in saving large models trained on TPU (multi-core).
Any updates? Maybe I should close this issue and open a duplicate on torch/xla?