Resuming training resets the logged step number
🐛 Bug
The change introduced in https://github.com/PyTorchLightning/pytorch-lightning/pull/11805 causes the logged step number to reset when training is resumed from a checkpoint: the logger connector now takes the logging step from a batch counter kept on the training epoch loop instead of from `trainer.global_step`, and that counter starts from 0 again after resuming. The relevant line:
https://github.com/PyTorchLightning/pytorch-lightning/blob/49a4a36ad45b937dd0124ecfb08eb7400dbf3950/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py#L122
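For reference, that fallback looks roughly like the snippet below. This is a paraphrase, not the exact library code, and the private attribute name `_batches_that_stepped` is an assumption that may differ between versions:

# Paraphrased excerpt of LoggerConnector.log_metrics at commit 49a4a36 (see the
# permalink above for the real code); not runnable on its own.
if step is None:
    # The counter lives on the training epoch loop and, at this commit, is not
    # restored from the checkpoint, so a resumed run logs from step 0 again.
    step = self.trainer.fit_loop.epoch_loop._batches_that_stepped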
To Reproduce
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(ckpt_path=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        num_sanity_val_steps=0,
        max_epochs=2,
        enable_model_summary=False,
        callbacks=ModelCheckpoint(dirpath="checkpoints", save_top_k=-1, filename="{epoch}", save_on_train_epoch_end=False),
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=ckpt_path)


if __name__ == "__main__":
    run()  # first run: train from scratch and save one checkpoint per epoch
    run("checkpoints/epoch=0.ckpt")  # second run: resume from the epoch-0 checkpoint
Running the script creates two TensorBoard log versions:
- version_0: steps 0 to 63
- version_1: steps 0 to 31
Expected behavior
- version_1: steps 31 to 63
This was the behavior before https://github.com/PyTorchLightning/pytorch-lightning/pull/11805
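One way to confirm which steps each run actually logged is to read the scalar events back from the two run directories. A minimal sketch, assuming the default TensorBoardLogger layout (lightning_logs/version_N under the working directory) and that the tensorboard package is installed:

# Print the range of steps recorded for the "train_loss" scalar in each run.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

for version in ("version_0", "version_1"):
    acc = EventAccumulator(f"lightning_logs/{version}")
    acc.Reload()  # parse the event files on disk
    steps = [event.step for event in acc.Scalars("train_loss")]
    print(version, "train_loss steps:", min(steps), "to", max(steps))

With the current behavior this prints 0 to 63 for version_0 and 0 to 31 for version_1, instead of the expected 31 to 63 continuation.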
Environment
- PyTorch Lightning Version (e.g., 1.5.0): master (49a4a36)
- Fault-tolerant training is off (PL_FAULT_TOLERANT_TRAINING=0)
cc @tchaton @rohitgr7 @akihironitta @awaelchli @ananthsub @ninginthecloud @carmocca

Hi, this workaround seems to work for my use case:
cc: @carmocca wdyt?
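For reference, a minimal sketch of what a workaround along these lines could look like. It reuses RandomDataset and BoringModel from the repro script above, assumes the logging step falls back to the epoch loop's private `_batches_that_stepped` counter, and seeds that counter from the "global_step" entry Lightning stores in its checkpoints. This is an illustration under those assumptions, not the exact snippet from the comment, and the private attribute may change between releases:

# Illustrative workaround sketch: seed the logging step counter before resuming.
# Trainer arguments are trimmed for brevity; RandomDataset and BoringModel are
# the classes from the repro script above.
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer


def run_with_step_workaround(ckpt_path=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(max_epochs=2, num_sanity_val_steps=0, log_every_n_steps=1)
    if ckpt_path is not None:
        # Restore the counter manually, since at this commit it is not part of
        # the loop state saved in the checkpoint.
        trainer.fit_loop.epoch_loop._batches_that_stepped = torch.load(ckpt_path)["global_step"]
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=ckpt_path)

Whether the manually seeded value survives checkpoint loading depends on the exact Lightning version, so treat this as a starting point rather than a fix.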