Adapting SWA for PyTorch Lightning.
Hello! I'm currently trying to add SWA support to my lightning module, following this post: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/

I have problems with loading the averaged weights after I finish training. I do it like this:
```python
def on_train_end(self):  # hook for SWA weight swapping
    if self.swa_enabled:
        print("Swapping model weights to SWA averages.")
        self.trainer.optimizers[0].swap_swa_sgd()
        print("Saving model with SWA weights to another checkpoint")
        self.save_checkpoint(os.path.join(self.trainer.checkpoint_callback.dirpath,
                                          f"swa_model_epoch={self.current_checkpoint}.ckpt"))
```
And I bump into the following problems:
1. It returns the error message `'MyCoolModel' object has no attribute 'save_checkpoint'` on the line `self.save_checkpoint(...)`, instead of saving the model as described in https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html
2. Even if I get past the first issue, I'm still required to update the BatchNorm statistics. As I understand it, the best way to do this is to run a forward pass over the whole training data. So my question is: is there a simple way to run one epoch through the training data without writing my own loops? Thanks in advance.
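For the second point, the torchcontrib SWA wrapper used in the blog post linked above ships a `bn_update` helper that performs exactly this one pass over the training data. A minimal sketch, assuming a `model` and a `train_loader` are already defined:

```python
import torch
import torchcontrib

# Wrap any base optimizer in torchcontrib's SWA helper
# (mirrors the setup from the linked blog post).
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
opt = torchcontrib.optim.SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)

# ... usual training loop here, calling opt.step() each iteration ...

# Copy the SWA running averages into the model parameters,
opt.swap_swa_sgd()
# then recompute the BatchNorm running statistics with a single
# forward pass over the training data; bn_update runs that loop for you.
opt.bn_update(train_loader, model)
```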

I believe the documentation is wrong: `save_checkpoint` is a method of the `Trainer` class, not the `LightningModule` class. https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_io.py#L247
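A minimal sketch of the corrected hook, with the checkpoint call moved onto the trainer (`swa_enabled` and `current_checkpoint` are the custom attributes from the snippet above):

```python
import os

def on_train_end(self):  # LightningModule hook, runs once training finishes
    if self.swa_enabled:
        # Swap the SWA running averages into the model parameters
        self.trainer.optimizers[0].swap_swa_sgd()
        # save_checkpoint is a Trainer method, so call it via self.trainer
        self.trainer.save_checkpoint(
            os.path.join(self.trainer.checkpoint_callback.dirpath,
                         f"swa_model_epoch={self.current_checkpoint}.ckpt"))
```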
@RafailFridman feel free to reopen this issue if you’re still having problems