on_epoch_end callback is called before on_validation_epoch_end
🐛 Bug
The `on_epoch_end` hook is called before the epoch ends.
What I'm doing:
- In the PL model I have `validation_epoch_end`, which computes an accuracy that I log with `self.log("val/meta_acc", meta_acc)`.
- I have a callback with a single method defined as:
class PrintAndSaveCallback(pl.callbacks.Callback):
def on_epoch_end(self, trainer, pl_module):
metrics = trainer.callback_metrics
metrics["val/meta_acc"] ### implemented checkpoint saving here using meta_acc value <-- this now fails with KeyError
- The trainer gets this callback as `callbacks=[PrintAndSaveCallback()]` (a minimal end-to-end sketch follows this list).
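For context, here is a minimal self-contained sketch of the setup described above. The model, data, and exact accuracy computation are placeholders (the issue does not show them); only the `val/meta_acc` logging and the callback mirror the report:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # Placeholder module; the real model in the report is not shown.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.layer(x).argmax(dim=-1)
        return {"correct": (preds == y).sum(), "total": torch.tensor(y.numel())}

    def validation_epoch_end(self, outputs):
        correct = sum(o["correct"] for o in outputs)
        total = sum(o["total"] for o in outputs)
        meta_acc = correct.float() / total
        self.log("val/meta_acc", meta_acc)  # the metric the callback reads

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class PrintAndSaveCallback(pl.callbacks.Callback):
    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        # On 1.2.x this raises KeyError: the training epoch's on_epoch_end
        # fires before the validation loop has logged "val/meta_acc".
        acc = metrics["val/meta_acc"]
        print("val/meta_acc =", float(acc))


train_loader = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,))), batch_size=16
)
val_loader = DataLoader(
    TensorDataset(torch.randn(32, 8), torch.randint(0, 2, (32,))), batch_size=16
)

trainer = pl.Trainer(callbacks=[PrintAndSaveCallback()], max_epochs=1)
trainer.fit(LitModel(), train_loader, val_loader)
```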
Up until I upgraded from 1.0.8 (directly to 1.2.2, and to 1.2.3 today), everything was working fine: `validation_epoch_end` was logging metrics and I could read them in the callback. Now I'm getting:
```
Epoch 0: 79%|███████████████████████████████▋        | 42/53 [00:08<00:02, 4.80it/s, loss=2.85, v_num=2]
Traceback (most recent call last):
File "cube3/trainer.py", line 279, in <module>
trainer_object.fit()
File "cube3/trainer.py", line 233, in fit
trainer.fit(model, train_loader, val_loader)
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 514, in fit
self.dispatch()
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 554, in dispatch
self.accelerator.start_training(self)
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in start_training
self.training_type_plugin.start_training(trainer)
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 111, in start_training
self._results = trainer.run_train()
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in run_train
self.train_loop.run_training_epoch()
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 558, in run_training_epoch
self.run_on_epoch_end_hook(epoch_output)
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 806, in run_on_epoch_end_hook
self.trainer.call_hook('on_epoch_end')
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1102, in call_hook
trainer_hook(*args, **kwargs)
File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py", line 115, in on_epoch_end
callback.on_epoch_end(self, self.lightning_module)
File "./cube3/networks/lemmatizer.py", line 278, in on_epoch_end
acc = metrics["meta_acc"]
KeyError: 'meta_acc'
```
This is because `metrics` in my code, which is actually just `trainer.callback_metrics`, is an empty dict. Furthermore, this fails in epoch 0, right after the sanity check, which finishes fine (I printed the metrics and it shows the 0 accuracy I expected from the sanity check).
What I tried is to switch `on_epoch_end` with `on_validation_epoch_end`, and it works. This led me to the conclusion that, since `on_epoch_end` sees an empty dict while `on_validation_epoch_end` sees the dict filled in by `validation_epoch_end` in the PL module, `on_epoch_end` is called in an incorrect order.
Again, this worked well with 1.0.8 (I don't know in which later version this behaviour changed).
Expected behavior
`on_epoch_end` should have the metrics from `validation_epoch_end`.
Environment
- CUDA:
  - GPU: GeForce RTX 2080 Ti
  - available: True
  - version: 11.1
- Packages:
  - numpy: 1.18.3
  - pyTorch_debug: False
  - pyTorch_version: 1.8.0+cu111
  - pytorch-lightning: 1.2.3
  - tqdm: 4.56.0
- System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.8.2
  - version: #140-Ubuntu SMP Thu Jan 28 05:20:47 UTC 2021
Issue Analytics
- Created 3 years ago
- Comments: 6 (4 by maintainers)
The order of the calls is such that `on_validation_end` is the last one, so I'd suggest not using that one to log anything, since some callbacks like `ModelCheckpoint`/`EarlyStopping` use this hook and expect everything they want to monitor to already be present there. Also, `on_epoch_end` is called after every train/eval/test epoch end, so I'd suggest using `on_validation_epoch_end` in your use case. Also, I don't think logging is supported in the `on_validation_end` hook.
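As an illustration of that suggestion (not code from the thread), the built-in `ModelCheckpoint` can monitor the logged metric directly instead of a hand-rolled callback; the parameter choices below are arbitrary:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# ModelCheckpoint does its saving in on_validation_end, i.e. after
# validation_epoch_end has logged "val/meta_acc".
checkpoint_cb = ModelCheckpoint(
    monitor="val/meta_acc",  # the metric logged via self.log(...)
    mode="max",              # keep the checkpoint with the highest accuracy
    save_top_k=1,
)

trainer = Trainer(callbacks=[checkpoint_cb], max_epochs=10)
# trainer.fit(model, train_loader, val_loader)
```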
OK, I see what's happening here. The problem is with the `on_epoch_end` hook. This hook is called after every loop, irrespective of train/eval/test. When it's called at the end of the training loop it expects `val/meta_acc`, but the val loop runs after this, and since `val/meta_acc` is not logged yet it raises the error here. So all you need to do is move the val-related stuff into the `on_validation_epoch_end` hook.
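A minimal sketch of that fix, keeping the callback from the report but moving the logic to the validation-specific hook (the checkpoint-saving body is still a placeholder):

```python
import pytorch_lightning as pl


class PrintAndSaveCallback(pl.callbacks.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # Runs after validation_epoch_end, so the logged metric is available.
        metrics = trainer.callback_metrics
        meta_acc = metrics["val/meta_acc"]
        print("val/meta_acc =", float(meta_acc))
        # ... checkpoint saving using meta_acc would go here ...
```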