Move the module to the precision dtype
🐛 Bug
@SeanNaren
When I use bf16 and check the dtype of the model, it seems like the model's precision is still fp32 (and I do not see the memory gains I expect). In other frameworks that support bf16 (like fairseq) the model's dtype is `torch.bfloat16`. Is there a simple example that "proves" that this feature reduces memory consumption as it should? I suspect that there might be something wrong (but of course, I might be wrong).
Thank you!
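For context, a minimal sketch of the kind of check described above (the `BoringModel` and data here are made up for illustration): even with `precision="bf16"`, the parameters still report `torch.float32`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_train_start(self):
        # Observation from this report: prints torch.float32, not torch.bfloat16
        print("parameter dtype:", next(self.parameters()).dtype)


train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = pl.Trainer(gpus=1, precision="bf16", max_steps=4)
trainer.fit(BoringModel(), train_loader)
```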
To Reproduce
Launch any job with `precision="bf16"` and compare with `precision=32` (for example, as in the sketch below).
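A rough way to compare peak GPU memory between the two settings, reusing the hypothetical `BoringModel` and `train_loader` from the sketch above:

```python
import torch
import pytorch_lightning as pl


def peak_memory_mib(precision):
    """Train a few steps at the given precision and return peak allocated GPU memory in MiB."""
    torch.cuda.reset_peak_memory_stats()
    trainer = pl.Trainer(gpus=1, precision=precision, max_steps=4)
    trainer.fit(BoringModel(), train_loader)  # defined in the sketch above
    return torch.cuda.max_memory_allocated() / 2**20


print("fp32 peak MiB:", peak_memory_mib(32))
print("bf16 peak MiB:", peak_memory_mib("bf16"))
```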
Expected behavior
This feature should save roughly 30-50% of memory, but I do not see such gains in Lightning.
Environment
- CUDA:
  - GPU:
    - GeForce RTX 3090
  - available: True
  - version: 11.3
- Packages:
  - numpy: 1.21.2
  - pyTorch_debug: False
  - pyTorch_version: 1.11.0
  - pytorch-lightning: 1.6.0dev
  - tqdm: 4.62.3
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.8.12
  - version: #74-Ubuntu SMP Tue Sep 17 17:06:04 UTC 2019
Additional context
BF16 is a very important feature. It is usually more stable than fp16, and Lightning should support it properly (models that are pretrained with bf16 should not be run in fp16) 😃
cc @borda @tchaton @rohitgr7 @carmocca @justusschock @awaelchli @akihironitta
@yuvalkirstain I’m glad it worked! Hopefully we’ll get the feature in soon 😃
@SeanNaren Yes, doing so results in less memory on the GPU with identical results, thank you!
Regarding converting the pl_module internally for users: definitely, IMO it makes more sense for the trainer to take care of that rather than the model.
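For anyone landing here before the feature ships, a sketch of the manual workaround being discussed (the exact call is not quoted in the thread, so treat this as an assumption): cast the module to bfloat16 yourself before fitting, so the weights are actually stored in bf16, which matches the memory savings reported above.

```python
import torch
import pytorch_lightning as pl

# Assumed workaround: convert parameters/buffers to bfloat16 manually instead of
# relying on autocast alone (reuses BoringModel/train_loader from the sketches above).
model = BoringModel().to(torch.bfloat16)

trainer = pl.Trainer(gpus=1, precision="bf16", max_steps=4)
trainer.fit(model, train_loader)
print("parameter dtype:", next(model.parameters()).dtype)  # torch.bfloat16
```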