
on_epoch_end callback is called before on_validation_epoch_end

See original GitHub issue

πŸ› Bug

The on_epoch_end is called before the epoch ends.

What I’m doing:

  • in the pl model I have validation_epoch_end, which computes an accuracy that I log with self.log("val/meta_acc", meta_acc) (a sketch of this appears after the list)
  • I have a callback with a single method defined as:
class PrintAndSaveCallback(pl.callbacks.Callback):
    def on_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        metrics["val/meta_acc"]  # checkpoint saving implemented here using the meta_acc value <-- this now fails with KeyError
  • trainer gets this callback as callbacks = [PrintAndSaveCallback],
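
For reference, here is a minimal sketch of the module side (the class name and the shape of the validation_step outputs are my assumptions; only the self.log("val/meta_acc", ...) call is taken from the setup above):

import torch
import pytorch_lightning as pl

class MetaModel(pl.LightningModule):
    # training_step / configure_optimizers etc. elided

    def validation_step(self, batch, batch_idx):
        # hypothetical: return a per-batch accuracy tensor under "acc"
        ...

    def validation_epoch_end(self, outputs):
        # aggregate the per-batch values and log the epoch-level metric;
        # self.log(...) is what puts it into trainer.callback_metrics
        meta_acc = torch.stack([o["acc"] for o in outputs]).mean()
        self.log("val/meta_acc", meta_acc)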

Everything was working fine until I upgraded from 1.0.8 (directly to 1.2.2, and to 1.2.3 today). validation_epoch_end was logging metrics and the callback read them fine. Now I'm getting:

Epoch 0:  79%|█████████████████████████████▏        | 42/53 [00:08<00:02,  4.80it/s, loss=2.85, v_num=2]
Traceback (most recent call last):
  File "cube3/trainer.py", line 279, in <module>
    trainer_object.fit()
  File "cube3/trainer.py", line 233, in fit
    trainer.fit(model, train_loader, val_loader)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 514, in fit
    self.dispatch()
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 554, in dispatch
    self.accelerator.start_training(self)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 111, in start_training
    self._results = trainer.run_train()
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in run_train
    self.train_loop.run_training_epoch()
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 558, in run_training_epoch
    self.run_on_epoch_end_hook(epoch_output)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 806, in run_on_epoch_end_hook
    self.trainer.call_hook('on_epoch_end')
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1102, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/echo/p3.8-test/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py", line 115, in on_epoch_end
    callback.on_epoch_end(self, self.lightning_module)
  File "./cube3/networks/lemmatizer.py", line 278, in on_epoch_end
    acc = metrics["meta_acc"]
KeyError: 'meta_acc'

This is because the metrics dict in my code, which is just trainer.callback_metrics, is empty. Furthermore, this fails in epoch 0, after the sanity check, which finishes fine (I printed the metrics there and it prints the 0 accuracy I expected from the sanity check).

What I tried is switching on_epoch_end to on_validation_epoch_end, and that works. This led me to the conclusion that on_epoch_end is called in an incorrect order: inside on_epoch_end, trainer.callback_metrics is an empty dict, while inside on_validation_epoch_end it is the filled-in dict produced by validation_epoch_end in the pl module.
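
In code terms, the working variant looks like this (same callback as above, with the checkpoint-saving body elided as in the original):

import pytorch_lightning as pl

class PrintAndSaveCallback(pl.callbacks.Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        # by the time this hook fires, validation_epoch_end has already
        # logged the metric, so the lookup no longer raises KeyError
        meta_acc = trainer.callback_metrics["val/meta_acc"]
        ...  # checkpoint saving using meta_acc goes here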

Again, this worked well with 1.0.8 (I don't know in which later version this behaviour changed).

Expected behavior

on_epoch_end should have the metrics from validation_epoch_end

Environment

  • CUDA:
    • GPU:
      • GeForce RTX 2080 Ti
    • available: True
    • version: 11.1
  • Packages:
    • numpy: 1.18.3
    • pyTorch_debug: False
    • pyTorch_version: 1.8.0+cu111
    • pytorch-lightning: 1.2.3
    • tqdm: 4.56.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.2
    • version: #140-Ubuntu SMP Thu Jan 28 05:20:47 UTC 2021

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

3 reactions
rohitgr7 commented, Mar 12, 2021

the order for the calls is:

on_validation_epoch_end
on_epoch_end
on_validation_end

on_validation_end is the last one, so I'd suggest not using it to log anything, since callbacks like ModelCheckpoint/EarlyStopping use this hook and expect everything to already be present there if you want to monitor something.

Also, on_epoch_end is called at the end of every train/eval/test epoch, so I'd suggest using on_validation_epoch_end for your use case.

Also, I don't think logging is supported in the on_validation_end hook.
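
To illustrate the monitoring point, a sketch that leans on the built-in ModelCheckpoint instead of a hand-rolled callback (the monitor key is the one from this issue; the other arguments are my assumptions):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# ModelCheckpoint reads its monitored metric in on_validation_end, so
# "val/meta_acc" must already be logged by then, which logging it from
# validation_epoch_end guarantees
checkpoint_cb = ModelCheckpoint(monitor="val/meta_acc", mode="max")
trainer = Trainer(callbacks=[checkpoint_cb])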

1 reaction
rohitgr7 commented, Mar 12, 2021

ok, I see what's happening here. The problem is with the on_epoch_end hook: it is called after every loop, irrespective of train/eval/test. When it's called at the end of the training loop, your callback expects val/meta_acc, but the val loop runs after this point, and since val/meta_acc has not been logged yet, it raises the error here. So all you need to do is move the val-related work into the on_validation_epoch_end hook.
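
If a callback really must stay on on_epoch_end, a defensive variant (my sketch, not from the thread) would simply skip the training-epoch invocation, where the metric is absent:

import pytorch_lightning as pl

class PrintAndSaveCallback(pl.callbacks.Callback):
    def on_epoch_end(self, trainer, pl_module):
        # this hook also fires after the *training* epoch, before the
        # val loop has run, so guard against the metric being missing
        meta_acc = trainer.callback_metrics.get("val/meta_acc")
        if meta_acc is None:
            return  # not the end of a validation epoch; nothing to save
        ...  # val-related work using meta_acc goes here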

Read more comments on GitHub >
