Add dataloader arg to Trainer.test() #1393

🚀 Feature

It would be nice to be able to run inference with a model via Trainer.test(model, test_dataloaders=test_loader).

Motivation

This would match the calling structure of Trainer.fit() and allow test to be called multiple times, on any dataset.

Pitch

Here’s a use case. After training a model using 5-fold cross-validation, you may want to stack the 5 checkpoints across multiple models, which will require a) out-of-fold (OOF) predictions and b) the 5 test predictions (which will be averaged). It would be cool if a & b could be generated as follows:

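# Model1 and Model2 stand for the two LightningModule classes being stacked.
# load_from_checkpoint is a classmethod that returns the restored model,
# so its result must be assigned; calling it and discarding the result does nothing.
for f in folds:
    model1 = Model1.load_from_checkpoint(f'path/to/model1_fold{f}.ckpt')
    trainer.test(model1, test_dataloaders=valid_loader)  # (a) OOF predictions on fold f's held-out split
    trainer.test(model1, test_dataloaders=test_loader)   # (b) this fold's test predictions

    model2 = Model2.load_from_checkpoint(f'path/to/model2_fold{f}.ckpt')
    trainer.test(model2, test_dataloaders=valid_loader)
    trainer.test(model2, test_dataloaders=test_loader)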
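Once the per-fold outputs exist, (a) and (b) come down to bookkeeping: each fold's validation predictions fill in their own rows of the OOF array, and the per-fold test predictions are averaged element-wise. A minimal stdlib-only sketch, assuming predictions are plain lists of floats and that each fold's held-out row indices are known; `assemble_oof` and `average_test_predictions` are hypothetical helpers, not part of Lightning.

```python
from typing import List, Sequence


def assemble_oof(n_rows: int,
                 fold_indices: Sequence[Sequence[int]],
                 fold_preds: Sequence[Sequence[float]]) -> List[float]:
    """Scatter each fold's validation predictions back to their original rows.

    fold_indices[f] holds the row indices fold f held out for validation;
    fold_preds[f] holds the predictions for those rows, in the same order.
    Rows that no fold covered stay NaN.
    """
    oof = [float("nan")] * n_rows
    for idx, preds in zip(fold_indices, fold_preds):
        if len(idx) != len(preds):
            raise ValueError("each fold needs one prediction per held-out row")
        for i, p in zip(idx, preds):
            oof[i] = p
    return oof


def average_test_predictions(per_fold: Sequence[Sequence[float]]) -> List[float]:
    """Average the test-set predictions produced by each fold's checkpoint."""
    if not per_fold:
        raise ValueError("need at least one fold")
    n = len(per_fold[0])
    if any(len(p) != n for p in per_fold):
        raise ValueError("all folds must predict the same test rows")
    # zip(*per_fold) yields, for each test row, that row's prediction from every fold
    return [sum(col) / len(per_fold) for col in zip(*per_fold)]
```

For example, `assemble_oof(4, [[0, 2], [1, 3]], [[0.1, 0.3], [0.2, 0.4]])` gives `[0.1, 0.2, 0.3, 0.4]`, and `average_test_predictions([[1.0, 2.0], [3.0, 4.0]])` gives `[2.0, 3.0]`. The OOF array and the averaged test predictions then become the training and test inputs for the stacking model.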

Alternatives

Maybe I’m misunderstanding how test works and there is an easier way? Or perhaps the best way to do this is to write an inference function as you would in pure PyTorch?
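The "write an inference function" alternative amounts to a loop that runs the model over every batch of a loader and concatenates the outputs. A minimal sketch, written framework-agnostic so it runs without PyTorch; `predict` is a hypothetical name, and the model is assumed to be any callable that returns one prediction per example in a batch.

```python
from typing import Any, Callable, Iterable, List, Sequence


def predict(model: Callable[[Any], Sequence[Any]],
            loader: Iterable[Any]) -> List[Any]:
    """Apply `model` to every batch from `loader` and concatenate the outputs.

    Framework-agnostic on purpose: with a real PyTorch module you would also
    call model.eval() beforehand and run this loop inside torch.no_grad().
    """
    outputs: List[Any] = []
    for batch in loader:
        # model(batch) is assumed to return one prediction per example
        outputs.extend(model(batch))
    return outputs
```

Because it accepts any iterable of batches, the same function can be called on the validation loader and then on the test loader for each fold, which is what the pitch wants `trainer.test` to do.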

Additional context

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 8 (7 by maintainers)

Top GitHub Comments

Ir1d commented, Apr 9, 2020 (5 reactions)

btw I’m interested in how to “train a model using 5-fold cross-validation” in PL.
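The usual recipe for k-fold cross-validation is to partition the sample indices into k disjoint validation folds and train one model per fold on the remaining indices. A stdlib-only sketch of the splitting step, following the same idea as scikit-learn's `KFold` (contiguous folds, optional shuffle); `k_fold_indices` is a hypothetical helper, and in PyTorch each index list could be wrapped with `torch.utils.data.Subset` to build that fold's train and validation loaders.

```python
import random
from typing import List, Optional, Tuple


def k_fold_indices(n: int, k: int = 5,
                   seed: Optional[int] = None) -> List[Tuple[List[int], List[int]]]:
    """Return k (train_idx, val_idx) pairs whose val_idx lists partition range(n).

    Fold sizes differ by at most one: the first n % k folds get one extra row.
    Pass a seed to shuffle the rows before splitting.
    """
    if not 2 <= k <= n:
        raise ValueError("need 2 <= k <= n")
    order = list(range(n))
    if seed is not None:
        random.Random(seed).shuffle(order)
    base, extra = divmod(n, k)
    folds: List[Tuple[List[int], List[int]]] = []
    start = 0
    for f in range(k):
        size = base + (1 if f < extra else 0)
        val = order[start:start + size]                  # held out for fold f
        train = order[:start] + order[start + size:]     # everything else
        folds.append((train, val))
        start += size
    return folds
```

Training then loops over the returned pairs, fitting a fresh model on `train` and saving one checkpoint per fold, which yields exactly the per-fold checkpoints the pitch loads.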

williamFalcon commented, Apr 9, 2020 (3 reactions)

Let’s do this:

  1. Add a test_dataloader method to .test()
  2. Remove the test_dataloader from .fit()?
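The plan above amounts to moving the test loader from `fit()` to `test()` while keeping the model's own `test_dataloader()` hook as a fallback. A toy sketch that assumes nothing about Lightning internals: `ToyTrainer` and `ToyModel` are hypothetical, and only the hook names `training_step`, `test_step` and `test_dataloader` are borrowed from `LightningModule`. (Current Lightning releases do accept a loader in `Trainer.test()`, under the keyword `dataloaders`.)

```python
from typing import Any, Iterable, List, Optional


class ToyTrainer:
    """Toy stand-in, NOT Lightning's Trainer: the test loader goes to test(), not fit()."""

    def fit(self, model: Any, train_dataloader: Iterable[Any]) -> None:
        for batch in train_dataloader:
            model.training_step(batch)

    def test(self, model: Any,
             test_dataloaders: Optional[Iterable[Any]] = None) -> List[Any]:
        # An explicitly passed loader wins; otherwise fall back to the model's
        # own test_dataloader() hook, so existing code keeps working.
        loader = test_dataloaders if test_dataloaders is not None else model.test_dataloader()
        return [model.test_step(batch) for batch in loader]


class ToyModel:
    """Hypothetical model exposing the three hooks ToyTrainer calls."""

    def __init__(self) -> None:
        self.seen: List[Any] = []

    def training_step(self, batch: Any) -> None:
        self.seen.append(batch)

    def test_step(self, batch: List[int]) -> int:
        return sum(batch)

    def test_dataloader(self) -> List[List[int]]:
        return [[1, 1]]
```

With this shape, `trainer.test(model, test_dataloaders=valid_loader)` and `trainer.test(model, test_dataloaders=test_loader)` can be called back to back on the same model, as in the pitch, while `trainer.test(model)` still uses the model's own hook.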