Return `resume` support for CheckpointCallback and `.train`, `predict`

🚀 Feature Request

Restore resume support for CheckpointCallback, Runner.train, and Runner.predict.

Motivation

That would be a great user-friendly feature: resuming a run is quite a common task during deep learning development.

Proposal

  1. Uncomment the CheckpointCallback code and check its correctness. From my perspective, it should be refactored a bit.
  2. Add support for it to Runner.
  3. Uncomment the CLI and update ConfigRunner and HydraRunner.

Proposed use case:

import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl, utils
from catalyst.data.transforms import ToTensor
from catalyst.contrib.datasets import MNIST

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.02)

loaders = {
    "train": DataLoader(
        MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()), batch_size=32
    ),
    "valid": DataLoader(
        MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()), batch_size=32
    ),
}

runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
# model training
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # catalyst[ml] required ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    logdir="./logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)

# here is the trick: a second run that resumes from the checkpoint saved by the first one
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    callbacks=[
        dl.AccuracyCallback(input_key="logits", target_key="targets", topk_args=(1, 3, 5)),
        dl.PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=10
        ),
        dl.AUCCallback(input_key="logits", target_key="targets"),
        # catalyst[ml] required ``pip install catalyst[ml]``
        # dl.ConfusionMatrixCallback(input_key="logits", target_key="targets", num_classes=10),
    ],
    # ----
    logdir="./logs2",
    resume="./logs/checkpoints/train.1.pth",
    # ----
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)
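
The issue does not say how `resume` would work internally. As a rough, framework-free sketch of the semantics the use case above appears to assume (restore the saved state, then continue counting epochs from where the first run stopped), something like the following could serve as a mental model. Everything here is hypothetical and is not Catalyst's API: the `train` and `save_checkpoint` helpers, the JSON checkpoint layout, and the `.json` file names are assumptions made for illustration only.

```python
import json
import os
import tempfile


def save_checkpoint(path, epoch, model_state, optimizer_state):
    """Write everything needed to continue training into one JSON file (hypothetical layout)."""
    with open(path, "w") as f:
        json.dump(
            {"epoch": epoch, "model_state": model_state, "optimizer_state": optimizer_state}, f
        )


def train(num_epochs, logdir, resume=None):
    """Toy loop: gradient descent on w**2, saving one checkpoint per epoch."""
    model_state = {"w": 1.0}
    optimizer_state = {"lr": 0.1, "steps": 0}
    start_epoch = 1

    if resume is not None:
        # Resuming: restore model and optimizer state, continue from the next epoch.
        with open(resume) as f:
            checkpoint = json.load(f)
        model_state = checkpoint["model_state"]
        optimizer_state = checkpoint["optimizer_state"]
        start_epoch = checkpoint["epoch"] + 1

    os.makedirs(logdir, exist_ok=True)
    for epoch in range(start_epoch, start_epoch + num_epochs):
        grad = 2 * model_state["w"]  # derivative of w**2
        model_state["w"] -= optimizer_state["lr"] * grad
        optimizer_state["steps"] += 1
        save_checkpoint(
            os.path.join(logdir, f"train.{epoch}.json"), epoch, model_state, optimizer_state
        )
    return model_state, optimizer_state


root = tempfile.mkdtemp()
# First run: one epoch, checkpoint written to logs/train.1.json.
train(1, os.path.join(root, "logs"))
# Second run: new logdir, resumed from the first run's checkpoint.
resumed_model, resumed_optimizer = train(
    1, os.path.join(root, "logs2"), resume=os.path.join(root, "logs", "train.1.json")
)
# Reference: two uninterrupted epochs should end in the same state.
straight_model, straight_optimizer = train(2, os.path.join(root, "logs3"))
```

In this sketch the resumed run ends in the same state as the uninterrupted two-epoch run, which is the property a real `resume` implementation would presumably need to preserve for model weights, optimizer state, and the epoch counter.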

Checklist

  • feature proposal description
  • motivation
  • extra proposal context / proposal alternatives review

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 10 (8 by maintainers)

Top GitHub Comments

y-ksenia commented, Nov 10, 2021 (1 reaction)

We can close it, I guess 🥳

y-ksenia commented, Jul 22, 2021 (1 reaction)

Yes, just started
