Remove unwrapping logic in strategies in favor of a direct reference to the original module
Proposed refactor
I propose to remove the fragile unwrapping logic in strategies in favor of keeping a direct reference to the original module.
Motivation
We currently have unwrapping logic in our strategies that looks something like this:
```python
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel


def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
    model = wrapped_model
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        model = unwrap_lightning_module(model.module)
    if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
        model = unwrap_lightning_module(model.module)
    ...
```
On the strategy, we only store a reference `self.model` to the wrapped model, and we additionally have a property
```python
@property
def lightning_module(self):
    return unwrap_lightning_module(self.model)
```
that returns the original `LightningModule`.
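To make the behavior concrete, here is a self-contained toy version of this round trip. `ToyWrapper` and `unwrap` below are simplified stand-ins invented for illustration, not the real PL classes:

```python
from torch import nn


class ToyWrapper(nn.Module):
    """Simplified stand-in for DDP or the PL wrapper bases: it stores the
    wrapped module under the attribute name ``module``."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


def unwrap(model: nn.Module) -> nn.Module:
    # Same peeling idea as unwrap_lightning_module, restricted to ToyWrapper.
    while isinstance(model, ToyWrapper):
        model = model.module
    return model


core = nn.Linear(4, 2)                  # stands in for the user's LightningModule
wrapped = ToyWrapper(ToyWrapper(core))  # e.g. DDP around a precision wrapper

assert unwrap(wrapped) is core
```

Every new wrapper type means another `isinstance` branch in the real function, which is exactly the copy-paste burden described below.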
This basic unwrapping logic has been around since PL first added support for DP and DDP. A function like this makes little sense, because the user always gives us the pure LightningModule as input to `Trainer.fit`. For some reason we kept it and kept extending it, leading to copy-pasting every time a new strategy is introduced. It also raises questions for users who want to build their own strategies. Overall, it is an unintuitive solution that seems overly complicated when one could simply keep the original reference saved in addition to the wrapped model.
In LightningLite, we recently refactored our `_LiteModule` wrapper to do exactly this, keeping both the original and the wrapped `nn.Module` in order to provide pass-through access to attributes (#12597).
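A minimal sketch of that idea, assuming a hypothetical wrapper class (the real `_LiteModule` has more responsibilities, e.g. input/output conversion):

```python
from torch import nn


class KeepBothWrapper(nn.Module):
    """Hypothetical sketch of the _LiteModule idea: store the wrapped module
    for forward calls and the original module for attribute pass-through."""

    def __init__(self, forward_module: nn.Module, original_module: nn.Module) -> None:
        super().__init__()
        self._forward_module = forward_module
        self._original_module = original_module

    def forward(self, *args, **kwargs):
        return self._forward_module(*args, **kwargs)

    def __getattr__(self, name):
        try:
            # nn.Module keeps submodules/parameters in special dicts.
            return super().__getattr__(name)
        except AttributeError:
            # Fall back to the original module for unknown attributes.
            return getattr(self._original_module, name)
```

Because the original reference is stored at wrap time, nothing ever needs to be peeled apart afterwards.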
Pitch
- Remove all `unwrap_lightning_module` functions.
- Modify `Strategy.connect()` to just save a reference to the given model.
- Make `Strategy.lightning_module` directly return this reference instead of calling the unwrap method(s), as sketched below.
There are no breaking changes. The result should be equivalent, but simpler, easier to test and thus less fragile.
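Concretely, the pitch amounts to something like this minimal sketch (an illustrative skeleton only, not the real `Strategy` class; `_lightning_module` is an assumed attribute name):

```python
from typing import Optional

import pytorch_lightning as pl
from torch import nn


class Strategy:
    """Illustrative skeleton; the real Strategy has many more members."""

    def __init__(self) -> None:
        self._lightning_module: Optional["pl.LightningModule"] = None
        self.model: Optional[nn.Module] = None

    def connect(self, model: "pl.LightningModule") -> None:
        # Save the original reference once; any later wrapping
        # (DDP, precision wrappers, ...) only replaces self.model.
        self._lightning_module = model
        self.model = model

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        # Direct reference; no recursive unwrapping required.
        return self._lightning_module
```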
Additional context
We have wanted to do this for a while; it came back on my radar after seeing a contributor struggle with it in #13501.
If you enjoy Lightning, check out our other projects! ⚡

- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
- Bolts: Pretrained SOTA deep learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.
Comments
I like simple 😃 Should I close my PR and let you do this?

Great proposal! This should also simplify typing.