Add dataloader arg to Trainer.test()
🚀 Feature
It would be nice if you could use a model for inference using:
Trainer.test(model, test_dataloaders=test_loader)
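The requested behaviour can be sketched with a toy trainer that uses no framework. This is not Lightning's implementation. `ToyModel` and `ToyTrainer` are hypothetical stand-ins that only mimic the `test_dataloader()` / `test_step()` hooks. The sketch shows one fallback rule: use the loader passed to `test()` if the caller supplies one, and otherwise use the model's own `test_dataloader()`.

```python
from typing import Any, Callable, Iterable, List, Optional


class ToyModel:
    """Hypothetical stand-in for a LightningModule: a predict function plus a default test loader."""

    def __init__(self, fn: Callable[[Any], Any], default_test_data: Iterable[Any]):
        self.fn = fn
        self.default_test_data = default_test_data

    def test_dataloader(self) -> Iterable[Any]:
        # The loader the model would use when the caller passes none
        return self.default_test_data

    def test_step(self, batch: Any) -> Any:
        return self.fn(batch)


class ToyTrainer:
    """Hypothetical trainer illustrating the proposed test() signature only."""

    def test(self, model: ToyModel, test_dataloaders: Optional[Iterable[Any]] = None) -> List[Any]:
        # Prefer the caller's loader; fall back to the model's own test_dataloader()
        loader = test_dataloaders if test_dataloaders is not None else model.test_dataloader()
        return [model.test_step(batch) for batch in loader]
```

With this rule, `trainer.test(model)` keeps its current behaviour. Calls to `trainer.test(model, test_dataloaders=...)` can be repeated on as many datasets as needed, such as a validation fold first and the test set second. Later Lightning releases appear to accept a dataloader argument in `Trainer.test()`, first as `test_dataloaders` and later renamed to `dataloaders`, so check the Trainer documentation for the installed version.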
Motivation
This would match the calling convention of Trainer.fit() and allow test to be called multiple times, on any dataset.
Pitch
Here's a use case. After training a model using 5-fold cross-validation, you may want to stack the 5 checkpoints across multiple models, which will require a) out-of-fold (OOF) predictions and b) the 5 test predictions (which will be averaged). It would be cool if a & b could be generated as follows:
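for f in folds:
    # load_from_checkpoint returns a new model instance, so keep the result
    model1 = model1.load_from_checkpoint(f'path/to/model1_fold{f}.ckpt')
    # valid_loader must hold the rows held out for fold f, so it changes with f
    trainer.test(model1, test_dataloaders=valid_loader)  # a) OOF predictions
    trainer.test(model1, test_dataloaders=test_loader)   # b) test predictions
    model2 = model2.load_from_checkpoint(f'path/to/model2_fold{f}.ckpt')
    trainer.test(model2, test_dataloaders=valid_loader)
    trainer.test(model2, test_dataloaders=test_loader)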
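The bookkeeping behind a) and b) does not depend on Lightning and can be sketched in plain Python. This is a minimal sketch under stated assumptions. `kfold_indices`, `stack_predictions` and the `load_fold_model` hook are hypothetical names, each "model" is just a function from one input to one prediction, and folds are contiguous and unshuffled. In the loop above, `load_fold_model(f)` corresponds to loading fold f's checkpoint, and the two inner prediction steps correspond to the two `trainer.test(...)` calls.

```python
from statistics import mean
from typing import Callable, List, Optional, Sequence, Tuple

# A fold is (train_indices, valid_indices) into the training set.
Fold = Tuple[List[int], List[int]]


def kfold_indices(n: int, k: int) -> List[Fold]:
    """Split range(n) into k contiguous folds without shuffling.

    The first n % k folds get one extra row, as in scikit-learn's unshuffled KFold.
    """
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds: List[Fold] = []
    start = 0
    for size in sizes:
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n))
        folds.append((train, valid))
        start += size
    return folds


def stack_predictions(
    x_train: Sequence[float],
    x_test: Sequence[float],
    k: int,
    load_fold_model: Callable[[int], Callable[[float], float]],
) -> Tuple[List[Optional[float]], List[float]]:
    """Return a) out-of-fold predictions and b) test predictions averaged over folds.

    load_fold_model(f) is a hypothetical hook returning the predict function of
    the model trained on fold f (e.g. one restored from that fold's checkpoint).
    """
    oof: List[Optional[float]] = [None] * len(x_train)
    test_per_fold: List[List[float]] = []
    for f, (_train_idx, valid_idx) in enumerate(kfold_indices(len(x_train), k)):
        predict = load_fold_model(f)
        # a) each training row is predicted only by the fold model that never trained on it
        for i in valid_idx:
            oof[i] = predict(x_train[i])
        # b) every fold model also predicts the full test set
        test_per_fold.append([predict(x) for x in x_test])
    # Average the k test-prediction vectors element-wise
    test_avg = [mean(column) for column in zip(*test_per_fold)]
    return oof, test_avg
```

Each training row receives exactly one OOF prediction. Each test row receives the mean of k predictions. A second-level (stacking) model would typically be fitted on `oof` and applied to `test_avg`. The OOF step is only valid if the checkpoint loaded for fold f is paired with the same held-out rows it never trained on.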
Alternatives
Maybe I'm misunderstanding how test works and there is an easier way? Or perhaps the best way to do this is to write an inference function as you would in pure PyTorch?
Additional context
Issue Analytics
- Created 3 years ago
- Comments: 8 (7 by maintainers)

btw I'm interested in how to "train a model using 5-fold cross-validation" in PL.
Let's do this: