Invoke training_epoch_end before validation_epoch_end

Looking at https://github.com/PyTorchLightning/pytorch-lightning/issues/2816#issuecomment-669336994 the current event order is:

on_validation_epoch_start
on_validation_epoch_end
on_train_start
on_epoch_start
on_train_epoch_start
on_validation_start
on_validation_epoch_start
on_validation_epoch_end
on_validation_end
on_epoch_end
on_train_epoch_end
on_epoch_start
on_train_epoch_start
on_validation_start
on_validation_epoch_start
on_validation_epoch_end
on_validation_end
on_epoch_end
on_train_epoch_end
on_train_end
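The logged order is consistent with validation being triggered from inside each training epoch, after the training batches but before on_epoch_end and on_train_epoch_end fire. The toy loop below is a simplified model written only to reproduce the list above. toy_fit is a hypothetical name, and this is not Lightning's actual loop code.

```python
def toy_fit(max_epochs, sanity_check=True):
    """Return hook names in the order logged above (a simplified model)."""
    events = []
    if sanity_check:
        # The sanity check runs validation before training starts.
        events += ["on_validation_epoch_start", "on_validation_epoch_end"]
    events.append("on_train_start")
    for _ in range(max_epochs):
        events += ["on_epoch_start", "on_train_epoch_start"]
        # (training batches would run here)
        # Validation fires from inside the training epoch, so its hooks
        # appear before the training epoch is closed.
        events += [
            "on_validation_start",
            "on_validation_epoch_start",
            "on_validation_epoch_end",
            "on_validation_end",
        ]
        events += ["on_epoch_end", "on_train_epoch_end"]
    events.append("on_train_end")
    return events


for name in toy_fit(max_epochs=2):
    print(name)
```

With max_epochs=2 this prints the same 20 hooks as the list above. The proposal below amounts to closing the training part of the epoch before the validation block starts.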

I am expecting the following order instead:

on_validation_epoch_start
on_validation_epoch_end
on_epoch_start
* on_train_start
on_train_epoch_start
* on_train_epoch_end
* on_train_end
on_validation_start
on_validation_epoch_start
on_validation_epoch_end
on_validation_end
on_epoch_end
on_epoch_start
on_train_epoch_start
on_validation_start
on_validation_epoch_start
on_validation_epoch_end
on_validation_end
on_epoch_end
on_train_epoch_end

I moved three events (marked with *): on_train_start, on_train_epoch_end and on_train_end.

Justification: The training phase always completes before the validation phase begins, but the current callback order does not reflect this.

In terms of LightningModule’s events, I am seeing the following invocation order:

training_step
training_step_end
validation_step
validation_step_end
validation_epoch_end
training_epoch_end

instead of:

training_step
training_step_end
* training_epoch_end
validation_step
validation_step_end
validation_epoch_end

Justification: The actual invocation order contradicts the order specified in the method documentation.

I’m trying to set a metric in validation_epoch_end that depends on the training loss, but I cannot do this unless training_epoch_end completes before validation_epoch_end.

Additionally, I think that anyone who reads train_loss from inside validation_epoch_end will actually get the value from the previous epoch. For example, if I read train_loss during validation of epoch 10, I will see the train_loss from epoch 9. Anyone relying on this value therefore risks corrupting their training.
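The off-by-one effect can be reproduced without Lightning at all. In this toy sketch, ToyModule and run are hypothetical names, the losses are made-up numbers, and the module only stores the epoch's training loss when training_epoch_end runs:

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule (not Lightning's API)."""

    def __init__(self):
        self.train_loss = None  # set in training_epoch_end
        self.seen_by_validation = []  # what validation_epoch_end read

    def training_epoch_end(self, step_losses):
        self.train_loss = sum(step_losses) / len(step_losses)

    def validation_epoch_end(self):
        self.seen_by_validation.append(self.train_loss)


def run(epoch_losses, train_end_first):
    module = ToyModule()
    for step_losses in epoch_losses:
        if train_end_first:  # the order the issue asks for
            module.training_epoch_end(step_losses)
            module.validation_epoch_end()
        else:  # the order observed in the issue
            module.validation_epoch_end()
            module.training_epoch_end(step_losses)
    return module.seen_by_validation


losses = [[4.0, 4.0], [3.0, 3.0], [2.0, 2.0]]  # made-up per-step losses
print(run(losses, train_end_first=False))  # [None, 4.0, 3.0]
print(run(losses, train_end_first=True))   # [4.0, 3.0, 2.0]
```

With the observed order, each validation pass sees the previous epoch's training loss (and nothing at all in the first epoch).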

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 8 (7 by maintainers)

Top GitHub Comments

1 reaction
tchaton commented, Oct 11, 2021

Dear @cowwoc,

Thanks for the clarification.

  1. If you use a LossTensor as described below, you have full control over when the reduction is performed, so you can access the epoch-level training loss (train_loss_epoch) in validation_epoch_end.

  2. Yes, you can, but you need to rely a bit more on the Trainer internals. The Trainer keeps track of is_last_batch, or you can use on_train_start and on_validation_start.

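import time

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torchmetrics import Metric


class LossTensor(Metric):
    # Running mean of a scalar loss. The reduction happens only when
    # compute() is called, so you decide when it runs.

    def __init__(self):
        super().__init__()
        # Sum both states across processes so compute() returns the global mean.
        self.add_state("loss", torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("counter", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, loss):
        self.loss += loss.detach()
        self.counter += 1

    def compute(self):
        return self.loss / self.counter


class Model(LightningModule):
    # Uses the 1.x hook API (validation_epoch_end was removed in 2.0).

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)  # placeholder model
        self.train_loss = LossTensor()
        self.val_loss = LossTensor()

    def on_train_start(self):
        self.t0 = time.time()

    def on_train_epoch_start(self):
        self.train_loss.reset()  # start a fresh per-epoch mean

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)  # placeholder loss
        self.train_loss(loss)  # accumulate only; no reduction yet
        return loss

    # Option A: record t1 after the last training batch of the epoch.
    # batch_progress is a Trainer internal and may change between versions.
    def on_train_batch_end(self, outputs, batch, batch_idx):
        if self.trainer.fit_loop.epoch_loop.batch_progress.is_last_batch:
            self.t1 = time.time()

    # Option B (alternative to A; keep only one): record t1 when
    # validation starts, ignoring the sanity check.
    def on_validation_start(self):
        if not self.trainer.sanity_checking:
            self.t1 = time.time()

    def on_validation_epoch_start(self):
        self.val_loss.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        self.val_loss(loss)

    def validation_epoch_end(self, outputs):
        # The sanity check runs before on_train_start, so t0, t1 and the
        # training loss do not exist yet.
        if self.trainer.sanity_checking:
            return
        # compute() performs the reduction here.
        self.log("trial_loss", self.train_loss.compute() + self.val_loss.compute() + (self.t1 - self.t0))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)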
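When LossTensor runs on several processes, each process holds its own loss and counter states, and the registered reduction functions combine them before compute() divides. The pure-Python check below assumes that cross-process reduction is applied to each state separately (as torchmetrics does with the reduction passed to add_state); combined_mean is a hypothetical helper and the numbers are made up. It shows that summing both states recovers the true mean, while averaging the per-process loss sums does not:

```python
def combined_mean(per_process, loss_reduce):
    """Combine per-process (loss_sum, count) states, then divide."""
    loss = loss_reduce([s for s, _ in per_process])
    counter = sum(c for _, c in per_process)
    return loss / counter


def mean(values):
    return sum(values) / len(values)


# Made-up example: process 0 saw losses 1, 2, 3; process 1 saw loss 2.
per_process = [(1.0 + 2.0 + 3.0, 3), (2.0, 1)]
true_mean = (1.0 + 2.0 + 3.0 + 2.0) / 4

print(true_mean)                         # 2.0
print(combined_mean(per_process, sum))   # 2.0
print(combined_mean(per_process, mean))  # 1.0
```

The gap appears whenever processes see different numbers of steps, because averaging the sums discards how many steps each process contributed.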

0 reactions
cowwoc commented, Apr 6, 2022

@tchaton Is this workaround still needed after the “Re-define the current_epoch boundary” changes in https://github.com/PyTorchLightning/pytorch-lightning/releases/tag/1.6.0?

Does training_epoch_end now get invoked before validation_epoch_end?
