Dice Loss/Score question
Hey Eugene,
First of all, thank you for this very useful package. I'm moving my environment from TF to PyTorch, and having your advanced losses is very helpful. However, when I trained the same model on the same data using the same loss functions in both frameworks, I noticed that I get very different loss numbers (I'm using the multilabel approach). Digging a little deeper into your code, I noticed that when you calculate the Dice Loss you always compute a per-sample AND per-channel loss and then average it. I don't understand why you are doing the per-channel calculation and averaging, rather than computing the Dice loss for all classes together. I can show what I mean with the dummy example below:
Let’s prepare 2 dummy multilabel matrices - ground truth (d_gt) and prediction (d_pr) with 3 classes each, 0 Red, 1 Green and 2 Blue:
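import numpy as np
import matplotlib.pyplot as plt

# Ground truth: a 5x5 red square, a 5x5 green square, blue everywhere else
d_gt = np.zeros(shape=(20, 20, 3))
d_gt[5:10, 5:10, 0] = 1
d_gt[10:15, 10:15, 1] = 1
d_gt[:, :, 2] = (1 - d_gt.sum(axis=-1, keepdims=True)).squeeze()
plt.imshow(d_gt)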
d_pr = np.zeros(shape=(20, 20, 3))
d_pr[4:9, 4:9, 0] = 1
d_pr[11:14, 11:14, 1] = 1
d_pr[:, :, 2] = (1 - d_pr.sum(axis=-1, keepdims=True)).squeeze()
plt.imshow(d_pr)
One can see that (using Dice Loss = 1 - Dice Score):
- Dice Loss for Red is 1 - (16 + 16) / (25 + 25) = 0.36
- Dice Loss for Green is 1 - (9 + 9) / (9 + 25) = 0.4706
- Dice Loss for Blue is 1 - (341 + 341) / (350 + 366) = 0.0475
However, the total Dice Loss for the whole picture is 1 - 2 * (16 + 9 + 341) / (2 * 400) = 0.085
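The arithmetic above can be checked with a few lines of NumPy. This is only a sketch of the hand calculation, not the package's implementation; the variable names `per_channel_loss` and `total_loss` are made up for illustration.

```python
import numpy as np

# Rebuild the two 20x20x3 masks from the example.
d_gt = np.zeros((20, 20, 3))
d_gt[5:10, 5:10, 0] = 1
d_gt[10:15, 10:15, 1] = 1
d_gt[:, :, 2] = 1 - d_gt[:, :, :2].sum(axis=-1)

d_pr = np.zeros((20, 20, 3))
d_pr[4:9, 4:9, 0] = 1
d_pr[11:14, 11:14, 1] = 1
d_pr[:, :, 2] = 1 - d_pr[:, :, :2].sum(axis=-1)

# Per-channel Dice loss: reduce over the two spatial axes only.
intersection = (d_gt * d_pr).sum(axis=(0, 1))
cardinality = d_gt.sum(axis=(0, 1)) + d_pr.sum(axis=(0, 1))
per_channel_loss = 1 - 2 * intersection / cardinality

# Pooled Dice loss: reduce over every axis, channels included.
total_loss = 1 - 2 * (d_gt * d_pr).sum() / (d_gt.sum() + d_pr.sum())

print(per_channel_loss.round(4))
print(round(per_channel_loss.mean(), 4))
print(round(total_loss, 4))
```

The per-channel values come out as 0.36, 0.4706 and 0.0475, their mean is 0.2927, and the pooled value is 0.085.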
After wrapping them into tensors
d_gt_tensor = torch.from_numpy(np.transpose(d_gt, (2, 0, 1))).unsqueeze(0)
d_pr_tensor = torch.from_numpy(np.transpose(d_pr, (2, 0, 1))).unsqueeze(0)
your Dice Loss (with from_logits=False) returns 0.2927, which is the average of the individual channel losses rather than the total loss. The culprit seems to be passing dims=(0,2) to the soft_dice_score function. I think dims=(1,2) should be passed instead, to get an individual score for each item in the batch. If this behaviour is intended, I'd appreciate some more explanation of why.
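The difference between the two reductions can be sketched in NumPy on arrays flattened to shape (batch, classes, pixels). The helper `soft_dice_loss` below is hypothetical and only mimics the reduction being discussed; it is not the package's soft_dice_score, and `axis=` here plays the role of the `dims` argument.

```python
import numpy as np

def soft_dice_loss(y_pred, y_true, dims, eps=1e-7):
    # Hypothetical helper: Dice loss reduced over the given axes.
    intersection = (y_pred * y_true).sum(axis=dims)
    cardinality = (y_pred + y_true).sum(axis=dims)
    return 1 - 2 * intersection / np.maximum(cardinality, eps)

def make_mask(red, green):
    # Channels-first 3x20x20 mask: red square, green square, blue elsewhere.
    m = np.zeros((3, 20, 20))
    m[0, red, red] = 1
    m[1, green, green] = 1
    m[2] = 1 - m[:2].sum(axis=0)
    return m

# Shape (1, 3, 400): one sample, three classes, flattened pixels.
y_true = make_mask(slice(5, 10), slice(10, 15)).reshape(1, 3, -1)
y_pred = make_mask(slice(4, 9), slice(11, 14)).reshape(1, 3, -1)

# dims=(0, 2): pool batch and pixels, keep one loss per class.
per_class = soft_dice_loss(y_pred, y_true, dims=(0, 2))
# dims=(1, 2): pool classes and pixels, keep one loss per sample.
per_sample = soft_dice_loss(y_pred, y_true, dims=(1, 2))

print(per_class.shape, round(per_class.mean(), 4))
print(per_sample.shape, round(per_sample[0], 4))
```

With dims=(0, 2) the result has one entry per class, and averaging it gives 0.2927. With dims=(1, 2) the result has one entry per batch item, 0.085 here. Averaging per class weights a small class such as Green as heavily as the large Blue background, whereas pooling lets the background dominate.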
A second, smaller question regarding your Dice Loss: why do you use from_logits=True by default?
Thanks in advance!
Top GitHub Comments
The signature of all losses is defined as forward(predictions, targets), so the second argument defines the ground-truth values. Secondly, the Dice metric is not defined when there are no positive targets (as when you pass an empty zeros tensor). To avoid NaN in the loss, it falls back to zero. Hope this clarifies why you are getting zero output in the first case.

Thanks for your reply. It is clear to me now that you would prefer a zero loss instead of a NaN loss. However, why would you say that the Dice loss is not defined when there are no positive targets? Looking up the Dice metric on Wikipedia seems to suggest it is just an intersection over union for two sets. Do you have a source for this which I could read to understand it better?
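The undefined case can be made concrete with the standard set form of the Dice score, 2|A ∩ B| / (|A| + |B|): when both masks are empty this is 0 / 0. The sketch below uses two hypothetical helpers, `dice_score` and `dice_loss_with_fallback`. The second is one possible way to get the zero fall-back described in the reply and is not claimed to be the package's code.

```python
import numpy as np

def dice_score(pred, target, eps=0.0):
    # Plain Dice score; eps is an optional smoothing term (assumption).
    intersection = (pred * target).sum()
    cardinality = pred.sum() + target.sum()
    with np.errstate(invalid="ignore", divide="ignore"):
        return (2 * intersection + eps) / (cardinality + eps)

def dice_loss_with_fallback(pred, target, eps=1e-7):
    # Hypothetical fall-back: report zero loss when the target is empty.
    if target.sum() == 0:
        return 0.0
    intersection = (pred * target).sum()
    cardinality = pred.sum() + target.sum()
    return 1 - 2 * intersection / max(cardinality, eps)

empty = np.zeros((20, 20))
full = np.ones((20, 20))

raw = dice_score(empty, empty)                 # 0 / 0
smoothed = dice_score(empty, empty, eps=1e-7)  # eps / eps
fallback = dice_loss_with_fallback(full, empty)
matched = dice_loss_with_fallback(full, full)

print(raw, smoothed, fallback, matched)
```

Without smoothing the score for two empty masks is NaN. A smoothing term turns it into 1.0 (a perfect match), and the fall-back variant simply reports a loss of 0.0 whenever the target has no positive pixels, regardless of the prediction.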