Training instability with Dice Loss/Tversky Loss
I am training a 2D UNet to segment fetal MR images using MONAI, and I have been observing some instability in training when using the MONAI Dice loss formulation. After some iterations, the loss jumps up and the network stops learning, as the gradients drop to zero. Here is an example (orange: loss on the training set, computed over 2D slices; blue: loss on the validation set, computed over 3D volumes).
After investigating several aspects (using the same deterministic seed), I have narrowed the issue down to the presence of the `smooth` term in both the numerator and denominator of the Dice loss:

`f = 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)`
When using the formulation

`f = 1.0 - (2.0 * intersection) / (denominator + smooth)`

without the `smooth` term in the numerator, the training was stable and no longer showed the unexpected behaviour:
[Note: this experiment was run for much longer to make sure the jump would not appear later in training.]
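For reference, here is a minimal sketch (not the exact MONAI code) of the two variants discussed above, assuming `pred` contains per-voxel probabilities and `target` the corresponding binary masks of the same shape (N, C, H, W):

```python
import torch


def soft_dice_loss(pred: torch.Tensor,
                   target: torch.Tensor,
                   smooth: float = 1e-5,
                   smooth_in_numerator: bool = True) -> torch.Tensor:
    """Illustrative soft Dice loss with the numerator smoothing made optional."""
    # Sum over the spatial dimensions, keeping batch and channel separate.
    dims = tuple(range(2, pred.dim()))
    intersection = (pred * target).sum(dims)
    denominator = pred.sum(dims) + target.sum(dims)

    numerator_smooth = smooth if smooth_in_numerator else 0.0
    dice = (2.0 * intersection + numerator_smooth) / (denominator + smooth)
    return (1.0 - dice).mean()
```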
The same pattern was also observed for the Tversky loss, so it could be worth investigating the stability of these losses to identify the best default option.
Software version
- MONAI version: 0.1.0+84.ga683c4e.dirty
- Python version: 3.7.4 (default, Jul 9 2019, 03:52:42) [GCC 5.4.0 20160609]
- Numpy version: 1.18.2
- Pytorch version: 1.4.0
- Ignite version: 0.3.0
Training information
- Using MONAI PersistentCache
- 2D UNet (as default in MONAI)
- Adam optimiser, LR = 1e-3, no LR decay
- Batch size: 10
Other tests
The following aspects were investigated but did not solve the instability issue:
- Gradient clipping
- Different optimisers (SGD, SGD + momentum)
- Transforming the binary segmentations to a two-channel approach ([background segmentation, foreground segmentation]); see the sketch after this list
- Choosing smooth = 1.0 as the default, as done in nnU-Net (https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/loss_functions/dice_loss.py). However, this made the behaviour even more severe and the jump happened sooner in the training.
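As referenced in the list above, a minimal sketch of the two-channel conversion, assuming a single-channel binary mask of shape (N, 1, H, W); the helper name is ours, not MONAI's:

```python
import torch


def to_two_channel(mask: torch.Tensor) -> torch.Tensor:
    """Convert (N, 1, H, W) binary masks to (N, 2, H, W):
    channel 0 = background, channel 1 = foreground."""
    foreground = (mask > 0.5).float()
    background = 1.0 - foreground
    return torch.cat([background, foreground], dim=1)
```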
The following losses were also investigated:
- Binary Cross Entropy --> stable
- Dice Loss + Binary Cross Entropy --> unstable
- Dice Loss (no smooth at numerator) + Binary Cross Entropy --> stable
- Tversky Loss --> unstable
- Tversky Loss (no smooth at numerator) --> stable (a sketch of this variant is given after this list)
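As referenced above, an illustrative Tversky loss (not MONAI's implementation) with the numerator smoothing made optional, mirroring the Dice experiment; `alpha` and `beta` weight false positives and false negatives respectively:

```python
import torch


def tversky_loss(pred: torch.Tensor,
                 target: torch.Tensor,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 smooth: float = 1e-5,
                 smooth_in_numerator: bool = False) -> torch.Tensor:
    """Illustrative Tversky loss; with alpha = beta = 0.5 it reduces to Dice."""
    dims = tuple(range(2, pred.dim()))
    tp = (pred * target).sum(dims)            # true positives
    fp = (pred * (1.0 - target)).sum(dims)    # false positives
    fn = ((1.0 - pred) * target).sum(dims)    # false negatives

    numerator_smooth = smooth if smooth_in_numerator else 0.0
    tversky = (tp + numerator_smooth) / (tp + alpha * fp + beta * fn + smooth)
    return (1.0 - tversky).mean()
```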
Top GitHub Comments
Hi all,
First, apologies for the delay in getting back to you about this issue – we have been running a few experiments to put together a MONAI and nnU-Net comparison and this required quite some time. We hope you will find our results interesting and informative.
@LucasFidon kindly ran a few experiments with nnU-Net, which allowed us to identify some implementation differences. Both MONAI and nnU-Net use the Dice formulation shown above. However, by default MONAI computes the Dice per element in the batch and then averages the loss across the batch (we will refer to this simply as “Dice”). We noticed that in nnU-Net the Dice is instead computed directly as a single value across the whole batch (i.e. not per image); the average is computed only across the channels, not across the batch elements. We refer to this approach as “Batch Dice”. We experimented with both formulations in both frameworks, and also tested the Dice loss and Dice + cross entropy loss at training, as well as the use of a single-channel or two-channel approach.
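A minimal sketch of the two reductions, assuming `pred` and `target` of shape (N, C, H, W); this is illustrative, not the framework code, and the smoothing is kept in the denominator only for simplicity:

```python
import torch


def dice_per_sample(pred: torch.Tensor, target: torch.Tensor,
                    smooth: float = 1e-5) -> torch.Tensor:
    """"Dice": one score per (batch element, channel), then averaged."""
    spatial = tuple(range(2, pred.dim()))
    inter = (pred * target).sum(spatial)                 # shape (N, C)
    denom = pred.sum(spatial) + target.sum(spatial)
    return 1.0 - ((2.0 * inter) / (denom + smooth)).mean()


def batch_dice(pred: torch.Tensor, target: torch.Tensor,
               smooth: float = 1e-5) -> torch.Tensor:
    """"Batch Dice": sums are taken over the whole batch before dividing,
    giving a single Dice score per channel."""
    dims = (0,) + tuple(range(2, pred.dim()))
    inter = (pred * target).sum(dims)                    # shape (C,)
    denom = pred.sum(dims) + target.sum(dims)
    return 1.0 - ((2.0 * inter) / (denom + smooth)).mean()
```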
The performance of the different trained models on the validation set is reported below (note: they are separated into groups A-E for our own clinical interpretation, but the group separation is not particularly relevant for this issue). The sampling strategy at training is also reported.
A few specifications:
We gather two main observations from these experiments:
@wyli: I also retrained the two-channel configuration with sigmoid instead of softmax. Unlike in the single-channel case, the gradients do not drop to zero (with either sigmoid or softmax), but I still observe some instability in the loss which I cannot fully explain:
Looking forward to hearing your comments, and happy to run more experiments to investigate this further!
Hi all,
@wyli: I finally have the updates on the DynUnet. First, I would like to point out that, together with @LucasFidon, we identified a major source of discrepancy in patch size and batch size between the manually selected values in MONAI and the automatically determined ones in nnU-Net. In MONAI I was using a much smaller patch size (roughly a factor of 5), which explains the very large performance gap.
I reran the experiments with the “standard” MONAI UNet, but using the same patch and batch sizes as determined by nnU-Net. These results are reported in red (first boxplot of each group) in the figure below. For the DynUnet, with respect to the MONAI tutorial I only modified the spacing transform so that it is applied only in the x-y plane, with no change of spacing along z (as our data are heavily affected by out-of-plane motion artefacts). All training was performed in 2D. The orange and light orange boxplots are the results with Dice + Xent and Batch Dice + Xent as losses, respectively.
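For reference, a hedged sketch of the in-plane-only resampling described above, assuming a recent MONAI release where non-positive `pixdim` entries fall back to the image's original spacing; the dictionary keys and the 0.8 mm in-plane spacing are placeholders, not values from this issue:

```python
from monai.transforms import Spacingd

# Resample only in the x-y plane; the negative z entry keeps the original
# through-plane spacing (behaviour of recent MONAI releases).
spacing = Spacingd(
    keys=["image", "label"],
    pixdim=(0.8, 0.8, -1.0),
    mode=("bilinear", "nearest"),
)
```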
Here are the results on our validation sets (not seen at training):
Overall, we managed to reduce the gap substantially compared to our previous results with very minor modifications of existing tools in MONAI. However, using the optimal hyperparameters as determined by nnU-Net played a big role in this.
Note: for the DynUnet, for both Dice and Batch Dice, I kept the original MONAI formulation of Dice (with the smooth term in both numerator and denominator) with smooth=1e-5. In this case it did not show the previously observed training instability. However, with the “standard” UNet this formulation would still show the instability, despite the optimised patch and batch sizes; for that experiment the smooth term at the numerator was set to 0.
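For anyone reproducing this, a hedged sketch of how the two configurations map onto the DiceLoss options available in recent MONAI releases (the `smooth_nr`/`smooth_dr`/`batch` arguments were added after the 0.1.0 version used here; the choice of `sigmoid=True` is an assumption for a single-channel setup):

```python
from monai.losses import DiceLoss

# Per-sample Dice with the smooth term removed from the numerator only.
dice_no_nr_smooth = DiceLoss(sigmoid=True, smooth_nr=0.0, smooth_dr=1e-5)

# "Batch Dice": a single Dice computed across the whole batch, per channel.
batch_dice_loss = DiceLoss(sigmoid=True, smooth_nr=1e-5, smooth_dr=1e-5, batch=True)
```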
Hope this helps, and please let me know if I can help further 😃