
TensorBoardLogger and WandbLogger do not track global_step when resuming training from a checkpoint (both manually, and with fault tolerant)

🐛 Bug

When resuming model training from a checkpoint, TensorBoardLogger and WandbLogger log metrics as if the global_step had been reset to 0 (even though the global_step in the trainer and the pl_module is accurate). This issue arises both when manually resuming training from a checkpoint via the ckpt_path argument of Trainer.fit and when using fault-tolerant training as shown here: https://github.com/PyTorchLightning/pytorch-lightning/blob/1.6.3/pl_examples/fault_tolerant/automatic.py
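
For reference, a minimal sketch of the manual resume path (the checkpoint path is illustrative, and SimpleMLP refers to the model defined in the repro script below):

from pytorch_lightning import Trainer

# First run: train for a few epochs and let Lightning save a checkpoint.
model = SimpleMLP()
trainer = Trainer(max_epochs=3, log_every_n_steps=1)
trainer.fit(model)

# Second run: resume manually via ckpt_path.
model = SimpleMLP()
trainer = Trainer(max_epochs=6, log_every_n_steps=1)
trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=2-step=9.ckpt")
# trainer.global_step is restored correctly, but the metrics logged by
# TensorBoardLogger/WandbLogger start again at step 0.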

To Reproduce

I’ve adapted the script linked above to test this, running v1.6.3 of pytorch-lightning:

import os
import random as python_random
from argparse import ArgumentParser
from time import sleep

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import _logger as log
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb


class RandomGetItemDataset(Dataset):
    """A dataset with random elements generated using global rng from torch, numpy and python."""

    def __init__(self, length, size):
        self.size = size
        self.len = length

    def __getitem__(self, index):
        t = torch.rand(self.size)
        n = torch.from_numpy(np.random.rand(self.size))
        p = torch.tensor([python_random.random() for _ in range(self.size)])
        sample = (index + (t + n + p) / 10).float()
        return sample

    def __len__(self):
        return self.len


class SimpleMLP(LightningModule):
    def __init__(self, fail_on_step: int = -1):
        super().__init__()
        self.layer = torch.nn.Linear(1, 2)
        self.seen_batches = []
        self.fail_on_step = fail_on_step

    def training_step(self, batch, batch_idx):
        if self.global_step == self.fail_on_step:
            log.info(
                f"READY TO BE KILLED WITH SIGTERM SIGNAL. " f"Run `kill -SIGTERM {os.getpid()}` in another terminal."
            )
            # this line is used to wait for you to send the signal to exit gracefully.
            while not self.trainer._terminate_gracefully:
                sleep(0.1)
        batch = batch["data"] if isinstance(batch, dict) else batch
        self.seen_batches.append(torch.stack(batch) if isinstance(batch, list) else batch)
        loss = sum(self.layer(b).sum() for b in batch)
        self.log("loss", loss.item())
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomGetItemDataset(3, 1))


def _run_training(default_root_dir=".", max_epochs=3, fail_on_step: int = -1, ckpt_path=None, logger=True):
    model = SimpleMLP(fail_on_step=fail_on_step)
    trainer = Trainer(default_root_dir=default_root_dir, max_epochs=max_epochs,
                      logger=logger, log_every_n_steps=1)
    trainer.fit(model, ckpt_path=ckpt_path)
    wandb.finish()  # no-op when no W&B run is active (e.g. the TensorBoard case)
    return model.seen_batches, model.parameters()


def main(args):
    seed_everything(42)
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "automatic"  # active fault tolerant automatic

    ckpt_path = ".pl_auto_save.ckpt"
    auto_restart_ckpt_path_exists = os.path.exists(ckpt_path)

    if args.emulate_kill_signal:
        fail_on_step = -1 if auto_restart_ckpt_path_exists else 4
        completed_batches = 4 if auto_restart_ckpt_path_exists else 5
    else:
        fail_on_step = -1
        completed_batches = 9

    if args.use_tb:
        logger = True
    else:
        logger = WandbLogger(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.wandb_run,
            id=args.wandb_run,
        )

    complete_batches, weights = _run_training(fail_on_step=fail_on_step, logger=logger)
    assert len(complete_batches) == completed_batches

    if not auto_restart_ckpt_path_exists and args.emulate_kill_signal:
        assert os.path.exists(ckpt_path)

    if auto_restart_ckpt_path_exists or not args.emulate_kill_signal:
        log.info([w for w in weights])


if __name__ == "__main__":
    parser = ArgumentParser(description="Fault Tolerant Under Signal Example")
    parser.add_argument(
        "--emulate_kill_signal",
        action="store_true",
        help="Whether you should gracefully kill the process with a `SIGTERM` signal.",
    )
    parser.add_argument(
        "--use_tb",
        action="store_true",
        help="Use TensorBoard instead of WandB.",
    )
    parser.add_argument(
        "-e", "--wandb_entity",
        type=str,
        default=None,
        help="Wandb entity.",
    )
    parser.add_argument(
        "-p", "--wandb_project",
        type=str,
        default=None,
        help="Wandb project.",
    )
    parser.add_argument(
        "-r", "--wandb_run",
        type=str,
        default=None,
        help="Wandb run.",
    )
    main(parser.parse_args())

With TensorBoard, run these in order:

python automatic.py --use_tb  (without fault)
python automatic.py --use_tb --emulate_kill_signal  (with fault)
python automatic.py --use_tb --emulate_kill_signal  (resume from fault)

This results in the following, where the epoch is logged properly but the step is not:

[Screenshot: TensorBoard plots in which the epoch curve continues across the resume but the logged step restarts from 0.]

With wandb, run these in order:

python automatic.py -e [wandb_entity] -p [wandb_project] -r no_fault  (without fault)
python automatic.py -e [wandb_entity] -p [wandb_project] -r fault --emulate_kill_signal  (with fault)
python automatic.py -e [wandb_entity] -p [wandb_project] -r fault --emulate_kill_signal  (resume from fault)

This results in the following, where the step is logged properly (because I’m only logging once per step, see #13016), but the global_step is reset:

[Screenshot: W&B plots in which the step axis continues correctly but the trainer/global_step metric resets to 0 after the resume.]

Expected behavior

The trainer/global_step metric in WandbLogger and the step in TensorBoardLogger should properly reflect the global_step state of the trainer/pl_module when resuming from checkpoints (either manually or automatically with fault-tolerant training).
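
Until this is fixed, a possible workaround sketch (not from the issue; it reuses the loss computation from the repro script) is to log through the underlying experiment object with an explicit step, so the value is tied to the restored global_step rather than the logger connector's own counter:

    def training_step(self, batch, batch_idx):
        loss = sum(self.layer(b).sum() for b in batch)
        if isinstance(self.logger, WandbLogger):
            # self.logger.experiment is the wandb Run object
            self.logger.experiment.log({"loss": loss.item()}, step=self.global_step)
        else:
            # TensorBoardLogger: self.logger.experiment is a SummaryWriter
            self.logger.experiment.add_scalar("loss", loss.item(), global_step=self.global_step)
        return loss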

Environment

* CUDA:
        - GPU:
        - available:         False
        - version:           10.2
* Packages:
        - numpy:             1.22.4
        - pyTorch_debug:     False
        - pyTorch_version:   1.11.0+cu102
        - pytorch-lightning: 1.6.3
        - tqdm:              4.64.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.10.4
        - version:           #171-Ubuntu SMP Fri Nov 5 11:55:11 UTC 2021

Issue Analytics

  • State: closed
  • Created: a year ago
  • Reactions: 2
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

1 reaction
manangoel99 commented, Jun 24, 2022

Also adding resume=True as an argument to your WandbLogger initialization might give you much cleaner looking plots!
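
A minimal sketch of that suggestion, reusing the arguments from the repro script (resume is not a named WandbLogger parameter, so it should be forwarded through the logger's extra keyword arguments to wandb.init):

from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(
    project=args.wandb_project,
    entity=args.wandb_entity,
    name=args.wandb_run,
    id=args.wandb_run,  # reusing the same id lets W&B pick up the existing run
    resume=True,
)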

1 reaction
manangoel99 commented, Jun 24, 2022

Hey guys! Engineer from W&B here! Sorry I’m a little late, but I managed to track this down to one line:

https://github.com/Lightning-AI/lightning/blob/5572797bc80b564286f111861e3d4b408344ae84/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L102

The solution to this is to set step = self.trainer.global_step

I’m not entirely sure if this was intentional, but I’ve pushed the fix anyway. It has caused some tests to break, so I’m looking into those; in the meantime, this change should get things up and running.
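
For context, a heavily paraphrased sketch of where that change lands (the surrounding connector code is approximate, not the exact upstream source; only the step = self.trainer.global_step line comes from the comment above):

# pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py (paraphrased)
class LoggerConnector:
    def log_metrics(self, metrics, step=None):
        ...
        if step is None:
            # Before: step was derived from a per-process counter that starts
            # from 0 again when a new process resumes from a checkpoint.
            # Proposed fix: use the trainer state restored from the checkpoint.
            step = self.trainer.global_step
        for logger in self.trainer.loggers:
            logger.log_metrics(metrics=metrics, step=step)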

