better checking of data returned from training_step
🚀 Feature
let's add more validation checks on what's returned from training_step and provide the user with useful error messages when they're not returning the right values.
Motivation
i feel like i've seen a lot of users confused about what they're supposed to return in training_step and validation_step. additionally, i don't think we document how we treat extra keys as "callback metrics" very well.
Pitch
what do you think about adding some structure and validation for Trainer's process_output method?
right now, we have expectations about a set of keys {progress_bar, log, loss, hiddens} and assume everything else is a callback metric. however, this is a silent assumption.
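To make the silent assumption concrete, here is a minimal sketch of how extra keys end up as callback metrics today (illustrative only; the real process_output logic is more involved, and split_callback_metrics is a hypothetical helper name):

```python
# Any key outside the reserved set silently becomes a callback metric,
# with no warning or validation. (Sketch, not the actual Trainer code.)
RESERVED_KEYS = {'progress_bar', 'log', 'loss', 'hiddens'}

def split_callback_metrics(output):
    """Return the keys that would silently be treated as callback metrics."""
    return {k: v for k, v in output.items() if k not in RESERVED_KEYS}

output = {'loss': 0.3, 'log': {'lr': 1e-3}, 'val_acc': 0.9}
print(split_callback_metrics(output))  # {'val_acc': 0.9}
```

A user who typos a reserved key (e.g. 'los') gets no error at all; the value just disappears into the callback metrics.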
we could instead enforce a more rigid structure:
{
    'loss': loss,            # REQUIRED
    'log': {},               # optional dict
    'progress_bar': {},      # optional dict
    'hiddens': [h0, c0],     # optional collection of tensors
    'metrics': {},           # optional dict
}
moreover, we can leverage pydantic to do the validation automatically and provide useful error messages out of the box when data validation fails.
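pydantic would let us declare the schema once and get descriptive errors for free; as a dependency-free illustration of the checks it would perform, here is a hand-rolled sketch (the function name and error messages are hypothetical, not existing Lightning API):

```python
ALLOWED_KEYS = {'loss', 'log', 'progress_bar', 'hiddens', 'metrics'}

def validate_step_output(output):
    """Sketch of the validation pydantic would give us automatically."""
    if not isinstance(output, dict):
        raise TypeError(
            f"training_step must return a dict, got {type(output).__name__}")
    if 'loss' not in output:
        raise KeyError("training_step output is missing the required 'loss' key")
    unexpected = set(output) - ALLOWED_KEYS
    if unexpected:
        raise KeyError(
            f"unexpected keys in training_step output: {sorted(unexpected)}")
    for key in ('log', 'progress_bar', 'metrics'):
        if key in output and not isinstance(output[key], dict):
            raise TypeError(
                f"'{key}' must be a dict, got {type(output[key]).__name__}")
    return output
```

Note the explicit rejection of unexpected keys: this is the opposite of today's behavior, where they silently become callback metrics, so it is exactly the backwards-incompatible part of the proposal.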
cc @PyTorchLightning/core-contributors
Alternatives
Do nothing, keep things as they are.
Additional context
This would be a backwards incompatible change.
@Borda given that this proposal is backwards incompatible, i think we should get more core contributors to weigh in on the proposed design before moving forward and implementing it.
one thing that is still giving me tension is the fact that there's a lot of overlap between log, progress_bar, and metrics. progress_bar almost always consists of a subset of log, and metrics (or as they currently stand, arbitrary keys) are typically used to store temporary values to be collated and logged at the end of an epoch. i think there's room for improvement here.

Shouldn't we favor the return type to be a strong type? I've always wondered why the step return type is not a dataclass or named tuple where loss is a required argument. We could keep the flexibility using some metadata dict argument.
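A dataclass version of that idea could look like the sketch below (the class name, field set, and the flexible metadata dict are all assumptions for illustration, not an agreed design):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class StepResult:
    """Hypothetical strongly-typed return value for training_step."""
    loss: Any                                           # required; a torch.Tensor in practice
    log: Dict[str, Any] = field(default_factory=dict)
    progress_bar: Dict[str, Any] = field(default_factory=dict)
    hiddens: Optional[List[Any]] = None
    metadata: Dict[str, Any] = field(default_factory=dict)  # keeps the flexibility

# forgetting loss now fails loudly at construction time with a TypeError,
# instead of silently misbehaving later in process_output
result = StepResult(loss=0.5, log={'train_loss': 0.5})
```

One nice property of this over a plain dict: typos in field names raise immediately (`StepResult(los=0.5)` is a TypeError), which addresses the silent-assumption problem without any custom validation code.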