Return `resume` support for CheckpointCallback and `.train`, `predict`
🚀 Feature Request
Bring back `resume` support for `CheckpointCallback`, `Runner.train`, and `Runner.predict`.
Motivation
That would be a great user-friendly feature: resuming from a checkpoint is quite a common task during deep learning development.
Proposal
- Uncomment the CheckpointCallback code and check its correctness; from my perspective, it should be refactored a bit.
- Add its support to Runner.
- Uncomment the CLI and update the ConfigRunner and HydraRunner.
Proposed use case:
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl, utils
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

loaders = {
    "train": DataLoader(
        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
    ),
    "valid": DataLoader(
        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
    ),
}

runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)

# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # catalyst[ml] required ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)

# here is the trick: a second run that resumes from the first run's checkpoint
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # catalyst[ml] required ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    # ----
    logdir="./logs2",
    resume="./logs/checkpoints/train.1.pth",
    # ----
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)
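
For discussion, here is a rough sketch of what `resume` could do once the checkpoint file has been read. It is not the Catalyst implementation: the function name `restore_from_checkpoint` and the checkpoint keys (`model_state_dict`, `optimizer_state_dict`, `epoch`) are assumptions made for illustration. It only relies on the `load_state_dict` method that PyTorch modules and optimizers provide, so it takes an already-loaded dict rather than a path.

```python
from typing import Any, Dict, Optional


def restore_from_checkpoint(
    checkpoint: Dict[str, Any],
    model: Any,
    optimizer: Optional[Any] = None,
) -> int:
    """Load saved states into ``model``/``optimizer`` and return the next epoch.

    ``checkpoint`` is an already-loaded dict, e.g. the result of
    ``torch.load(path, map_location="cpu")``. The key names used below are
    assumptions for this sketch, not a confirmed Catalyst checkpoint layout.
    """
    # Model weights are mandatory: without them there is nothing to resume.
    model.load_state_dict(checkpoint["model_state_dict"])

    # Optimizer state is optional, e.g. it is not needed for inference.
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Continue with the epoch after the last recorded one; if no epoch
    # was recorded, start from epoch 1.
    return int(checkpoint.get("epoch", 0)) + 1
```

With the objects from the use case above, this would be called as `restore_from_checkpoint(torch.load("./logs/checkpoints/train.1.pth", map_location="cpu"), model, optimizer)`, provided the file actually stores those keys. Returning the next epoch keeps the epoch counter and checkpoint naming consistent across the two runs.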
Alternatives
Additional context
Checklist
- feature proposal description
- motivation
- extra proposal context / proposal alternatives review
FAQ
Please review the FAQ before submitting an issue:
- I have read the documentation and FAQ
- I have reviewed the minimal examples section
- I have checked the changelog for main framework updates
- I have read the contribution guide
- I have joined Catalyst slack (#__questions channel) for issue discussion
Issue Analytics
- State:
- Created 2 years ago
- Comments:10 (8 by maintainers)
Top GitHub Comments
We can close it, I guess 🥳
Yes, just started