
Checkpoint not saved even though loss is improving (PyTorch Lightning example)

See original GitHub issue

Describe the bug
I was trying to implement the PyTorch Lightning example https://github.com/Project-MONAI/MONAI/blob/master/examples/notebooks/spleen_segmentation_3d_lightning.ipynb (non-notebook code attached) and noticed that the checkpoints were not saved.

Expected behavior
As shown in the screencap, mean_val_dice was improving, but no checkpoint was saved to disk. Is the problem similar to the one in https://github.com/PyTorchLightning/pytorch-lightning/issues/511?

I replaced

return {'log': tensorboard_logs}

in def validation_epoch_end(self, outputs): with

return {'mean_val_dice': torch.tensor(mean_val_dice), 'log': tensorboard_logs}

And the checkpoints were saved now, but the {val_loss:.2f}-{val_dice:.2f} placeholders in the checkpoint filename were still not updated?
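
Judging by that behaviour, the 0.7-era checkpoint callback appears to read only the top-level keys of the dict returned by validation_epoch_end, not the values nested under 'log'. A minimal sketch that returns both values the filename template expects (the val_loss aggregation is an assumption; the original code only aggregates dice):

def validation_epoch_end(self, outputs):
    # aggregate the per-batch values returned by validation_step
    mean_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    mean_val_dice = torch.stack([x['val_dice'].mean() for x in outputs]).mean()
    tensorboard_logs = {'val_loss': mean_val_loss, 'val_dice': mean_val_dice}
    # top-level keys are what ModelCheckpoint can monitor and substitute into
    # the '{epoch}-{val_loss:.2f}-{val_dice:.2f}' filename template
    return {'val_loss': mean_val_loss, 'val_dice': mean_val_dice,
            'log': tensorboard_logs}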

Screenshots
[screenshot from the original issue: training log showing mean_val_dice improving while no checkpoint is written]

Environment (please complete the following information):

  • OS: Ubuntu 18.04, see screencap for monai.config.print_config()

Additional context
Non-notebook code:

import os
import sys
import glob
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import monai
from monai.transforms import \
    Compose, LoadNiftid, AddChanneld, ScaleIntensityRanged, RandCropByPosNegLabeld, \
    CropForegroundd, RandAffined, Spacingd, Orientationd, ToTensord
from monai.data import list_data_collate
from monai.inferers import sliding_window_inference
from monai.networks.layers import Norm
from monai.metrics import compute_meandice
from monai.utils import set_determinism
from pytorch_lightning import LightningModule, Trainer, loggers
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

monai.config.print_config()

class Net(LightningModule):
    def __init__(self):
        super().__init__()
        self._model = monai.networks.nets.UNet(dimensions=3, in_channels=1, out_channels=2,
                                               channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
                                               num_res_units=2, norm=Norm.BATCH)
        self.loss_function = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        self.best_val_dice = 0
        self.best_val_epoch = 0

    def forward(self, x):
        return self._model(x)

    def prepare_data(self):
        # set up the correct data path
        # 1.6 GB dataset from https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2
        # http://medicaldecathlon.com/
        # Medical Segmentation Decathlon: Generalisable 3D Semantic Segmentation
        data_root = '/home/petteri/Task09_Spleen'
        train_images = sorted(glob.glob(os.path.join(data_root, 'imagesTr', '*.nii.gz')))
        train_labels = sorted(glob.glob(os.path.join(data_root, 'labelsTr', '*.nii.gz')))
        data_dicts = [{'image': image_name, 'label': label_name}
                      for image_name, label_name in zip(train_images, train_labels)]
        train_files, val_files = data_dicts[:-9], data_dicts[-9:]

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose([
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=('bilinear', 'nearest')),
            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=['image', 'label'], source_key='image'),
            # randomly crop out patch samples from big image based on pos / neg ratio
            # the image centers of negative samples must be in valid image area
            RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', size=(96, 96, 96), pos=1,
                                   neg=1, num_samples=4, image_key='image', image_threshold=0),
            # user can also add other random transforms
            # RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=1.0, spatial_size=(96, 96, 96),
            #             rotate_range=(0, 0, np.pi/15), scale_range=(0.1, 0.1, 0.1)),
            ToTensord(keys=['image', 'label'])
        ])
        val_transforms = Compose([
            LoadNiftid(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), interp_order=('bilinear', 'nearest')),
            Orientationd(keys=['image', 'label'], axcodes='RAS'),
            ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=['image', 'label'], source_key='image'),
            ToTensord(keys=['image', 'label'])
        ])

        # we use cached datasets - these are 10x faster than regular datasets
        self.train_ds = monai.data.CacheDataset(
            data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4
        )
        self.val_ds = monai.data.CacheDataset(
            data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4
        )
        # self.train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
        # self.val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)

    def train_dataloader(self):
        train_loader = DataLoader(self.train_ds, batch_size=2, shuffle=True,
                                  num_workers=4, collate_fn=list_data_collate)
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_ds, batch_size=1, num_workers=4)
        return val_loader

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
        return optimizer

    def training_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        output = self.forward(images)
        loss = self.loss_function(output, labels)
        tensorboard_logs = {'train_loss': loss.item()}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        images, labels = batch['image'], batch['label']
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(images, roi_size, sw_batch_size, self.forward)
        loss = self.loss_function(outputs, labels)
        value = compute_meandice(y_pred=outputs, y=labels, include_background=False,
                                 to_onehot_y=True, mutually_exclusive=True)
        return {'val_loss': loss, 'val_dice': value}

    def validation_epoch_end(self, outputs):
        val_dice = 0
        num_items = 0
        for output in outputs:
            val_dice += output['val_dice'].sum().item()
            num_items += len(output['val_dice'])
        mean_val_dice = val_dice / num_items
        tensorboard_logs = {'val_dice': mean_val_dice}
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
            print('Validation dice improved, a new checkpoint _should be saved_ (Petteri)')
        print('current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}'.format(
            self.current_epoch, mean_val_dice, self.best_val_dice, self.best_val_epoch))
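        # NOTE: only 'log' is returned here, so the metric the checkpoint callback
        # monitors never reaches it; this is the line the report above replaces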
        return {'log': tensorboard_logs}

## Run the training
# initialise the LightningModule
net = Net()

# set up loggers and checkpoints
tb_logger = loggers.TensorBoardLogger(save_dir='logs')
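# NOTE: the {val_loss:.2f}/{val_dice:.2f} placeholders below are filled from the
# metrics that validation_epoch_end makes available; with only 'log' returned,
# they are never substituted into the filename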
checkpoint_callback = ModelCheckpoint(filepath='logs/{epoch}-{val_loss:.2f}-{val_dice:.2f}')

# initialise Lightning's trainer.
trainer = Trainer(gpus=[0],
                  max_epochs=600,
                  logger=tb_logger,
                  checkpoint_callback=checkpoint_callback,
                  show_progress_bar=True,
                  num_sanity_val_steps=1
                  )
# train
trainer.fit(net)

print('train completed, best_metric: {:.4f} at epoch {}'.format(net.best_val_dice, net.best_val_epoch))
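
For readers on newer PyTorch Lightning releases (1.x and later), the dict-return pattern above was replaced by self.log; a minimal sketch of the equivalent checkpointing setup under that assumption (monitor='val_dice' with mode='max' is my choice, matching the dice metric used above):

# sketch for PyTorch Lightning >= 1.0: inside the LightningModule, report metrics
# with self.log (e.g. self.log('val_loss', mean_val_loss) and
# self.log('val_dice', mean_val_dice) in validation_epoch_end)
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath='logs',
    filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}',
    monitor='val_dice',  # callbacks see metrics reported via self.log
    mode='max',          # keep the checkpoint with the highest mean dice
)
trainer = Trainer(gpus=1, max_epochs=600, callbacks=[checkpoint_callback])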

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 8 (6 by maintainers)

Top GitHub Comments

1 reaction
marksgraham commented, Jun 29, 2020

Hi @Nic-Ma, I think we should. I can do a PR with the update. Mark

0 reactions
Nic-Ma commented, Jun 29, 2020

> Hi @Nic-Ma, I think we should. I can do a PR with the update. Mark

Sounds good. We will release MONAI v0.2 soon, so if you can submit a quick PR, that would be great. Thanks in advance!
