
Training instability with Dice Loss/Tversky Loss

See original GitHub issue

I am training a 2D UNet to segment fetal MR images using MONAI, and I have been observing some instability in training when using the MONAI Dice loss formulation. After some iterations, the loss jumps up and the network stops learning, as the gradients drop to zero. Here is an example (orange is the loss on the training set computed over 2D slices, blue is the loss on the validation set computed over 3D volumes):

[figure: training and validation loss curves showing the jump]
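A quick way to confirm the vanishing-gradient behaviour described above is to log the global gradient norm after each backward pass. This is a minimal sketch, not the code used in these experiments; model, loss_fn and optimizer are placeholders for the 2D UNet, the Dice loss and the optimiser.

```python
import torch

def global_grad_norm(model: torch.nn.Module) -> float:
    """L2 norm over all parameter gradients (0.0 if no gradients yet)."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().pow(2).sum().item()
    return total ** 0.5

# Inside the training loop (sketch):
# loss = loss_fn(model(images), labels)
# loss.backward()
# print(f"step {step}: loss={loss.item():.4f}  grad_norm={global_grad_norm(model):.3e}")
# optimizer.step()
# optimizer.zero_grad()
```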

After investigating several aspects (using the same deterministic seed), I’ve narrowed down the issue to the presence of the smooth term in both the numerator and denominator of the Dice loss:

f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

When using the formulation

f = 1.0 - (2.0 * intersection) / (denominator + smooth)

without the smooth term in the numerator, the training was stable and no longer showed unexpected behaviour:

[figure: training and validation loss curves for the modified formulation]

[Note: this experiment was trained for much longer to make sure the jump would not appear later in the training.]
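For reference, the two variants discussed above can be written as a single helper. This is a minimal sketch of the soft Dice computation on a (B, C, H, W) prediction/target pair, not the actual MONAI implementation (which handles activations, one-hot encoding and additional reduction options):

```python
import torch

def soft_dice_loss(pred: torch.Tensor, target: torch.Tensor,
                   smooth: float = 1e-5, smooth_in_numerator: bool = True) -> torch.Tensor:
    """Soft Dice loss for (B, C, H, W) tensors, computed per image/channel, then averaged."""
    dims = (2, 3)                                    # spatial dimensions
    intersection = (pred * target).sum(dims)         # shape (B, C)
    denominator = pred.sum(dims) + target.sum(dims)  # shape (B, C)
    smooth_num = smooth if smooth_in_numerator else 0.0
    # default formulation: smooth in numerator and denominator;
    # smooth_in_numerator=False gives the variant that trained stably here
    f = 1.0 - (2.0 * intersection + smooth_num) / (denominator + smooth)
    return f.mean()                                  # average over batch and channels
```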

The same pattern was observed also for the Tversky Loss, so it could be worth investigating the stability of the losses to identify the best default option.

Software version

  • MONAI version: 0.1.0+84.ga683c4e.dirty
  • Python version: 3.7.4 (default, Jul 9 2019, 03:52:42) [GCC 5.4.0 20160609]
  • Numpy version: 1.18.2
  • Pytorch version: 1.4.0
  • Ignite version: 0.3.0

Training information

  • Using MONAI PersistentCache
  • 2D UNet (as default in MONAI)
  • Adam optimiser, LR = 1e-3, no LR decay
  • Batch size: 10
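A rough sketch of this configuration, assuming a single-channel segmentation task; the UNet channel/stride values are illustrative rather than the exact ones used, and the keyword for the spatial dimensionality differs between MONAI releases (dimensions in early versions, spatial_dims later):

```python
import torch
from monai.networks.nets import UNet

# Illustrative 2D UNet; the channels/strides below are an assumption, not the
# exact configuration used in the experiments reported in this issue.
model = UNet(
    dimensions=2, in_channels=1, out_channels=1,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # no LR decay
batch_size = 10
```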

Other tests

The following aspects were investigated but did not solve the instability issue:

The following losses were also investigated:

  • Binary Cross Entropy --> stable
  • Dice Loss + Binary Cross Entropy --> unstable
  • Dice Loss (no smooth at numerator) + Binary Cross Entropy --> stable (a sketch of this combination follows the list)
  • Tversky Loss --> unstable
  • Tversky Loss (no smooth at numerator) --> stable
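A minimal sketch of the Dice + BCE combination referenced above, with the smooth term dropped from the numerator (the configuration reported as stable). The equal weighting of the two terms is an assumption; the exact weights are not stated in the issue.

```python
import torch
import torch.nn.functional as F

def dice_bce_loss(logits: torch.Tensor, target: torch.Tensor,
                  smooth: float = 1e-5) -> torch.Tensor:
    """Dice (no smooth at numerator) + binary cross entropy for (B, C, H, W) tensors."""
    probs = torch.sigmoid(logits)
    dims = (2, 3)                                      # spatial dimensions
    intersection = (probs * target).sum(dims)
    denominator = probs.sum(dims) + target.sum(dims)
    dice = (1.0 - (2.0 * intersection) / (denominator + smooth)).mean()
    bce = F.binary_cross_entropy_with_logits(logits, target)
    return dice + bce  # equal weighting (assumption)
```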

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 8
  • Comments: 38 (14 by maintainers)

Top GitHub Comments

13 reactions
martaranzini commented, Sep 21, 2020

Hi all,

First, apologies for the delay in getting back to you about this issue – we have been running a few experiments to put together a MONAI and nnU-Net comparison and this required quite some time. We hope you will find our results interesting and informative.

@LucasFidon kindly ran a few experiments with nnU-Net and this allowed us to identify some implementation differences. Both MONAI and nnU-Net use the Dice formulation:

f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)

However, as a default, MONAI computes the Dice per element in the batch and then averages the loss across the batch (we will refer to this simply as “Dice”). We noticed that in nnU-Net the Dice is instead computed directly as a single value across the whole batch (i.e. not per image); the average is computed only across the channels, not across the batch elements. We refer to this approach as “Batch Dice”. We experimented with both formulations in both frameworks, and also tested both the Dice loss and the Dice + cross-entropy loss at training, as well as single-channel and 2-channel approaches.
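To make the distinction concrete, here is a minimal sketch of the two reductions (a simplification of what MONAI and nnU-Net actually do): “Dice” scores each image separately and averages over the batch, while “Batch Dice” pools the sums over the whole batch and averages only across channels.

```python
import torch

def dice_per_image(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
    """'Dice': one score per (image, channel), then averaged over the batch."""
    dims = (2, 3)                                    # spatial dims of (B, C, H, W)
    inter = (pred * target).sum(dims)                # shape (B, C)
    denom = pred.sum(dims) + target.sum(dims)
    return (1.0 - (2.0 * inter + smooth) / (denom + smooth)).mean()

def batch_dice(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5):
    """'Batch Dice': sums pooled across the whole batch, averaged over channels only."""
    dims = (0, 2, 3)                                 # batch + spatial dims
    inter = (pred * target).sum(dims)                # shape (C,)
    denom = pred.sum(dims) + target.sum(dims)
    return (1.0 - (2.0 * inter + smooth) / (denom + smooth)).mean()
```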

The performance of the different trained models on the validation set is reported below (note: they are separated into groups A-E for our own clinical interpretation, but the group separation is not particularly relevant for this issue). The sampling strategy at training is also reported.

[figure: validation performance of the MONAI and nnU-Net configurations]

A few specifications:

  • MONAI – Dice no smooth at numerator used the formulation f = 1.0 - (2.0 * intersection) / (denominator + smooth)
  • nnU-Net – Batch Dice + Xent, 2-channel, ensemble indicates ensemble performance from 5-fold cross validation at training
  • NeuroImage indicates a published two-step approach on our dataset, and it is reported just for reference.

We gather two main observations from these experiments:

  1. Training instability: On our dataset, nnU-Net did not present the training instability observed with the default MONAI implementation of Dice. This was also confirmed when known implementation differences were ruled out (nnU-Net – Dice, 2-channel, uniform sampling). nnU-Net also generally provides better performance. @FabianIsensee, are there any other implementation differences that could explain our results?
  2. Dice vs Batch Dice: In both frameworks, the Batch Dice implementation clearly outperforms the “normal” Dice computation. This could be an interesting feature to be added in MONAI – happy to open another issue/PR about this.

@wyli: I also retrained the two-channel model with sigmoid instead of softmax. Compared with the single-channel case, the gradients do not drop to zero (with either sigmoid or softmax), but I still observe some instability in the loss which I cannot fully explain:

[figure: loss curves for the two-channel sigmoid and softmax experiments]
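For clarity, a minimal sketch of the single-channel sigmoid and two-channel softmax/sigmoid parameterisations compared here; the tensor shapes are illustrative, and the Dice term would be computed on the resulting probabilities.

```python
import torch

B, H, W = 10, 128, 128
logits_1ch = torch.randn(B, 1, H, W)   # single-channel network output (foreground only)
logits_2ch = torch.randn(B, 2, H, W)   # two-channel network output (background, foreground)

# single-channel: foreground probability via sigmoid
probs_fg = torch.sigmoid(logits_1ch)

# two-channel with softmax: channels compete, probabilities sum to 1 per pixel
probs_softmax = torch.softmax(logits_2ch, dim=1)

# two-channel with sigmoid: channels treated independently
probs_sigmoid = torch.sigmoid(logits_2ch)
```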

Looking forward to hearing your comments, and happy to run more experiments to investigate this further!

5 reactions
martaranzini commented, Oct 28, 2020

Hi all,

@wyli: I finally have the updates on the DynUnet. First, I would like to point out that together with @LucasFidon we identified a huge source of discrepancy in patch size and batch size between the manually selected values in MONAI and the automatically determined ones in nnU-Net. In MONAI I was using a much smaller patch size (by roughly a factor of 5), which explains the very large performance gap.

I reran the experiments with the “standard” MONAI UNet, but using the same patch and batch size as determined by nnU-Net. These results are reported in red (first boxplot of each group) in the figure below. For the DynUnet, with respect to the MONAI tutorial I only modified the spacing transform to apply it in the x-y plane only, with no change of spacing along z (as our data is heavily affected by out-of-plane motion artefacts). All training was performed in 2D. The orange and light orange boxplots are the results with Dice + Xent and Batch Dice + Xent as the loss, respectively.

Here are the results on our validation sets (not seen at training):

[figure: validation results for the standard UNet and DynUnet configurations]

Overall, we managed to reduce the gap substantially compared to our previous results with very minor modifications of existing tools in MONAI. However, using the optimal hyperparameters as determined by nnU-Net played a big role in this.

Note: for the DynUnet, for both Dice and Batch Dice I kept the original MONAI formulation of Dice, f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth), with smooth = 1e-5. In this case, it did not show the previously observed instability at training. However, with the “standard” UNet, this formulation would still show the instability, despite the optimised patch and batch size. For that experiment, the smooth term at the numerator was set to 0.

Hope this helps, and please let me know if I can help further 😃


