Checkpoint not saved even though loss improving (PyTorch Lightning example)
Describe the bug
I was trying to implement the PyTorch Lightning example https://github.com/Project-MONAI/MONAI/blob/master/examples/notebooks/spleen_segmentation_3d_lightning.ipynb (non-notebook code attached) and noticed that the checkpoints were not being saved.
Expected behavior
As the screencap shows, mean_val_dice was improving, but no checkpoint was saved to disk. Is the problem similar to the one in https://github.com/PyTorchLightning/pytorch-lightning/issues/511?
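For context, a minimal sketch of how the callback is presumably wired in the ~0.7-era PyTorch Lightning API used here (an assumption; check the installed version): ModelCheckpoint compares its monitor key against the metrics dict returned by validation_epoch_end, and when that key never appears it has nothing to compare and skips saving.

```python
# Sketch only, assuming the ~0.7-era PyTorch Lightning API: the callback
# saves a checkpoint only when its `monitor` key appears in the metrics
# returned from validation_epoch_end.
checkpoint_callback = ModelCheckpoint(
    filepath='logs/{epoch}-{val_loss:.2f}-{val_dice:.2f}',
    monitor='val_dice',  # must match a key in the returned metrics dict
    mode='max',          # dice improves upward, unlike a loss
    save_top_k=1,
)
```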
In def validation_epoch_end(self, outputs): I replaced
return {'log': tensorboard_logs}
with
return {'mean_val_dice': torch.tensor(mean_val_dice), 'log': tensorboard_logs}
and the checkpoints are saved now, but the {val_loss:.2f}-{val_dice:.2f} placeholders are not filled in in the checkpoint filename?
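If that reading is right, a sketch of an epoch-end hook that should satisfy both the callback and the filename template is to return the metrics under the exact names the template uses (untested against this example; the aggregation mirrors the validation_step return values in the code below):

```python
def validation_epoch_end(self, outputs):
    # aggregate the per-step 'val_dice' / 'val_loss' values from validation_step
    val_dice, val_loss, num_items = 0.0, 0.0, 0
    for output in outputs:
        val_dice += output['val_dice'].sum().item()
        val_loss += output['val_loss'].item()
        num_items += len(output['val_dice'])
    mean_val_dice = val_dice / num_items
    mean_val_loss = val_loss / len(outputs)
    tensorboard_logs = {'val_dice': mean_val_dice, 'val_loss': mean_val_loss}
    # return the metrics under the same keys the filename template uses,
    # so '{val_loss:.2f}-{val_dice:.2f}' can be formatted at save time
    return {'val_loss': torch.tensor(mean_val_loss),
            'val_dice': torch.tensor(mean_val_dice),
            'log': tensorboard_logs}
```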
Screenshots
(screencaps referenced above: training output showing mean_val_dice improving, and the monai.config.print_config() output)
Environment (please complete the following information):
- OS: Ubuntu 18.04; see the screencap for the output of monai.config.print_config()
Additional context
Non-notebook code:
```python
import os
import sys
import glob
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import monai
from monai.transforms import \
    Compose, LoadNiftid, AddChanneld, ScaleIntensityRanged, RandCropByPosNegLabeld, \
    CropForegroundd, RandAffined, Spacingd, Orientationd, ToTensord
from monai.data import list_data_collate
from monai.inferers import sliding_window_inference
from monai.networks.layers import Norm
from monai.metrics import compute_meandice
from monai.utils import set_determinism

from pytorch_lightning import LightningModule, Trainer, loggers
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

monai.config.print_config()


class Net(LightningModule):

    def __init__(self):
        super().__init__()
        self._model = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=2,
                                               channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
                                               num_res_units=2, norm=Norm.BATCH)
        self.loss_function = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        self.best_val_dice = 0
        self.best_val_epoch = 0

    def forward(self, x):
        return self._model(x)

    def prepare_data(self):
        # set up the correct data path
        # 1.6 GB dataset from https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2
        # http://medicaldecathlon.com/
        # Medical Segmentation Decathlon: Generalisable 3D Semantic Segmentation
        data_root = '/home/petteri/Task09_Spleen'
        train_images = sorted(glob.glob(os.path.join(data_root, 'imagesTr', '*.nii.gz')))
        train_labels = sorted(glob.glob(os.path.join(data_root, 'labelsTr', '*.nii.gz')))
        data_dicts = [{'image': image_name, 'label': label_name}
                      for image_name, label_name in zip(train_images, train_labels)]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose([
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=('bilinear', 'nearest')),
            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=['image', 'label'], source_key='image'),
            # randomly crop out patch samples from the big image based on pos / neg ratio
            # the image centers of negative samples must be in a valid image area
            RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,
                                   neg=1, num_samples=4, image_key='image', image_threshold=0),
            # user can also add other random transforms
            # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1.0,
            #             spatial_size=(96, 96, 96), rotate_range=(0, 0, np.pi / 15),
            #             scale_range=(0.1, 0.1, 0.1)),
            ToTensord(keys=['image', 'label'])
        ])
        val_transforms = Compose([
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=('bilinear', 'nearest')),
            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=['image', 'label'], source_key='image'),
            ToTensord(keys=['image', 'label'])
        ])

        # we use cached datasets - these are 10x faster than regular datasets
        self.train_ds = monai.data.CacheDataset(
            data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4
        )
        self.val_ds = monai.data.CacheDataset(
            data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4
        )
        # self.train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # self.val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    def train_dataloader(self):
        train_loader = DataLoader(self.train_ds, batch_size=2, shuffle=True,
                                  num_workers=4, collate_fn=list_data_collate)
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {'train_loss': loss.item()}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_function(outputs, labels)
        value = compute_meandice(y_pred=outputs, y=labels, include_background=False,
                                 to_onehot_y=True, mutually_exclusive=True)
        return {'val_loss': loss, 'val_dice': value}

    def validation_epoch_end(self, outputs):
        val_dice = 0
        num_items = 0
        for output in outputs:
            val_dice += output['val_dice'].sum().item()
            num_items += len(output['val_dice'])
        mean_val_dice = val_dice / num_items
        tensorboard_logs = {'val_dice': mean_val_dice}
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
            print('Validation loss improved, a new checkpoint _should be saved_ (Petteri)')
        print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
            self.current_epoch, mean_val_dice, self.best_val_dice, self.best_val_epoch))
        return {'log': tensorboard_logs}


## Run the training
# initialise the LightningModule
net = Net()

# set up loggers and checkpoints
tb_logger = loggers.TensorBoardLogger(save_dir='logs')
checkpoint_callback = ModelCheckpoint(filepath='logs/{epoch}-{val_loss:.2f}-{val_dice:.2f}')

# initialise Lightning's trainer
trainer = Trainer(gpus=[0],
                  max_epochs=600,
                  logger=tb_logger,
                  checkpoint_callback=checkpoint_callback,
                  show_progress_bar=True,
                  num_sanity_val_steps=1)

# train
trainer.fit(net)
print('train completed, best_metric: {:.4f} at epoch {}'.format(net.best_val_dice, net.best_val_epoch))
```
Top GitHub Comments
Hi @Nic-Ma, I think we should. I can do a PR with the update. Mark
Sounds good, we will release MONAI v0.2 soon; if you can submit a quick PR, that would be great. Thanks in advance!