Epochs terminating early incorrectly
🐛 Bug
My understanding is that custom dataloaders are expected to reset themselves and raise `StopIteration` when `__next__` is called and there is nothing more to yield (e.g. at the end of an epoch).
However, starting from `pytorch-lightning == 1.6.0`, Lightning appears to recognize that there are no more batches to yield, and this final `__next__` call at the end of the epoch is no longer made. Consequently, on the next epoch, when the first such `__next__` call is made, the reset and `StopIteration` are triggered and the epoch ends immediately.
This results in every other epoch being skipped/terminating early (a minimal illustration of this pattern follows the timings):
Time for epoch 0: 0.14461779594421387
Time for epoch 1: 0.000400543212890625 <- skipped
Time for epoch 2: 0.11101579666137695
Time for epoch 3: 0.00034236907958984375 <- skipped
Time for epoch 4: 0.11014986038208008
Time for epoch 5: 0.0003437995910644531 <- skipped
Time for epoch 6: 0.1101064682006836
Time for epoch 7: 0.00034737586975097656 <- skipped
Time for epoch 8: 0.1124570369720459
Time for epoch 9: 0.0006518363952636719 <- skipped
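To make the failure mode concrete, here is a minimal, hypothetical iterator (a sketch only; `SelfResettingLoader` is not DALI code) that mimics `auto_reset=True`: the reset happens inside the same `__next__` call that raises `StopIteration`, so if that final call is skipped, the following epoch starts with an immediate `StopIteration` and appears empty.

```python
class SelfResettingLoader:
    """Hypothetical iterator that resets itself inside the __next__ call
    that raises StopIteration, similar to auto_reset=True in DALI."""

    def __init__(self, num_batches):
        self.num_batches = num_batches
        self._i = 0

    def __iter__(self):
        # Deliberately no reset here: the reset only happens in __next__.
        return self

    def __next__(self):
        if self._i >= self.num_batches:
            self._i = 0          # reset for the next epoch...
            raise StopIteration  # ...and signal that this epoch is over
        self._i += 1
        return self._i


loader = SelfResettingLoader(num_batches=3)

# When the consumer makes the final __next__ call (pre-1.6 behaviour),
# the loader resets itself and every epoch yields all batches:
print(list(loader))  # [1, 2, 3]
print(list(loader))  # [1, 2, 3]

# If the consumer stops after exactly num_batches batches and never makes
# that final call, the internal counter is never reset, so the next epoch's
# very first __next__ raises StopIteration and the epoch appears empty:
_ = [next(loader) for _ in range(3)]
print(list(loader))  # [] - the "skipped" epoch seen in the timings above
```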
To Reproduce
```python
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
import torch
import os
import time


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        x = x[0]['random'].float()  # unpack random data
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def on_train_epoch_start(self):
        self.start_time = time.time()

    def on_train_epoch_end(self):
        print(f'Time for epoch {self.current_epoch}: {time.time() - self.start_time}')

    def setup(self, stage=None):
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        mnist_pipeline = BoringPipeline(batch_size=2, device='gpu', device_id=device_id, shard_id=shard_id, num_shards=num_shards, num_threads=8)
        self.train_loader = DALIGenericIterator(mnist_pipeline, output_map=['random'], size=100, last_batch_policy=LastBatchPolicy.PARTIAL, auto_reset=True)

    def train_dataloader(self):
        return self.train_loader


@pipeline_def
def BoringPipeline(device, shard_id, num_shards):
    return fn.random.coin_flip(shape=32)


def run():
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_val_batches=0,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=10,
        enable_model_summary=False,
    )
    trainer.fit(model)


if __name__ == "__main__":
    run()
```
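As a possible workaround (a sketch only, untested here), the iterator could be reset explicitly once the epoch has finished instead of relying on the `StopIteration`-driven auto reset. `DALIGenericIterator` exposes a documented `reset()` method, but whether this interacts cleanly with `auto_reset=True` and this Lightning version is an assumption. The sketch builds on the `BoringModel` from the reproduction script above:

```python
class BoringModelWithManualReset(BoringModel):
    # Workaround sketch: reset the DALI iterator ourselves at the end of each
    # training epoch, so the next epoch starts from a fresh state even if
    # Lightning never makes the final StopIteration-raising __next__ call.
    def on_train_epoch_end(self):
        super().on_train_epoch_end()  # keep the timing printout from the repro
        self.train_loader.reset()     # documented DALIGenericIterator method
```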
Expected behavior
On `pytorch-lightning==1.5.10`, there is no such issue, and we get the following times:
Time for epoch 0: 0.12346410751342773
Time for epoch 1: 0.09711241722106934
Time for epoch 2: 0.09668135643005371
Time for epoch 3: 0.09659314155578613
Time for epoch 4: 0.09643316268920898
Time for epoch 5: 0.0983741283416748
Time for epoch 6: 0.09633684158325195
Time for epoch 7: 0.09726572036743164
Time for epoch 8: 0.09633612632751465
Time for epoch 9: 0.10026073455810547
Environment
* CUDA:
- GPU:
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- NVIDIA A40
- available: True
- version: 11.5
* Packages:
- numpy: 1.22.3
- pyTorch_debug: False
- pyTorch_version: 1.11.0+cu115
- pytorch-lightning: 1.6.2
- tqdm: 4.64.0
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.10
- version: #1 SMP Debian 5.15.15-2~bpo11+1 (2022-02-03)
Additional context
https://github.com/NVIDIA/DALI/pull/3923 has been merged. It should be available in the next nightly build (tomorrow at earliest) and DALI 1.16.
@carmocca,
It seems the original fix was incomplete - https://github.com/NVIDIA/DALI/pull/4048 should add the missing part.
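For anyone following up, a rough way to check whether the installed DALI build is new enough to contain the referenced fixes is a simple version gate (a sketch; the exact minimum version that includes both PRs is an assumption based on the comments above):

```python
# Sketch: require a DALI build at least as new as the 1.16 release mentioned
# above; nightly builds report dev versions, so this is only a rough proxy.
import nvidia.dali
from packaging import version  # assumption: the 'packaging' package is installed

installed = version.parse(nvidia.dali.__version__)
if installed < version.parse("1.16"):
    raise RuntimeError(
        f"DALI {nvidia.dali.__version__} likely predates the fixes referenced above"
    )
```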