
Adapting SWA for PyTorch Lightning.

See original GitHub issue

Hello! I’m currently trying to add SWA support to my LightningModule, following this post: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/. I’m having problems with loading the averaged weights after I finish training. I do it like this:

def on_train_end(self):  # hook for SWA weight swapping
    if self.swa_enabled:
        print("Swapping model weights to SWA averages.")
        self.trainer.optimizers[0].swap_swa_sgd()
        print("Saving model with SWA weights to another checkpoint")
        self.save_checkpoint(
            os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                f"swa_model_epoch={self.current_checkpoint}.ckpt",
            )
        )

And I run into the following problems:

  1. It returns the error message 'MyCoolModel' object has no attribute 'save_checkpoint' on the self.save_checkpoint(...) line, instead of saving the model as described in https://pytorch-lightning.readthedocs.io/en/latest/weights_loading.html.
  2. Even if I get past the first issue, I still need to update the BatchNorm statistics. As I understand it, the best way to do this is to run a forward pass over the whole training data. So my question is: is there a simple way to run one epoch through the training data without writing my own loops? (One possible approach is sketched below.) Thanks in advance!
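On the BatchNorm point, the torchcontrib SWA wrapper used in the linked blog post ships a bn_update(loader, model) helper that performs exactly this single pass over the training data, so no hand-written loop is needed. A minimal sketch of how the hook above might use it; it assumes the swa_enabled flag from the question, a torchcontrib.optim.SWA-wrapped optimizer at self.trainer.optimizers[0], and that the module defines its own train_dataloader():

def on_train_end(self):
    # Sketch only: swap in the SWA averages, then refresh the BatchNorm
    # running statistics with a single pass over the training data.
    if self.swa_enabled:
        swa_optimizer = self.trainer.optimizers[0]
        swa_optimizer.swap_swa_sgd()
        # torchcontrib's SWA wrapper exposes bn_update(loader, model),
        # which runs the forward passes and updates the running stats.
        swa_optimizer.bn_update(self.train_dataloader(), self)

On newer PyTorch releases, torch.optim.swa_utils.update_bn(loader, model) provides the same behaviour, and recent Lightning versions ship a StochasticWeightAveraging callback that handles the averaging for you.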

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
jeremyjordan commented, Apr 4, 2020

It returns error message ‘MyCoolModel’ object has no attribute ‘save_checkpoint’ on the line self.save_checkpoint…

I believe the documentation is wrong; save_checkpoint is a method on the Trainer class, not the LightningModule class.

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/trainer/training_io.py#L247
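In practice that means the call in the question’s hook just needs to go through the trainer reference the module already holds. A minimal sketch of the corrected snippet, keeping the swa_enabled flag and checkpoint naming from the question (those are the author’s own attributes, not part of the Lightning API):

def on_train_end(self):
    if self.swa_enabled:
        self.trainer.optimizers[0].swap_swa_sgd()
        # save_checkpoint is a Trainer method, so call it via self.trainer.
        self.trainer.save_checkpoint(
            os.path.join(
                self.trainer.checkpoint_callback.dirpath,
                f"swa_model_epoch={self.current_checkpoint}.ckpt",
            )
        )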

0 reactions
jeremyjordan commented, Apr 5, 2020

@RafailFridman feel free to reopen this issue if you’re still having problems

