
Multi-class 2D segmentation: merged classes in prediction


Hello, everyone!

I recently discovered MONAI and I am trying to get it to work on some of our projects. I really like the modularity of the library, but I fear I am still confused by some of the steps. In particular, I am now facing a strange issue with the results of the network predictions.

I adapted the example from https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/brats_segmentation_3d.ipynb for a 2D case of cell segmentation. As ground truth for the training, I am using a (3-channel, one-hot) mask with one channel for the background pixels, one for the cell bodies, and one for the cell boundaries (see the details below). The prediction seems to work more or less okay, except that channel 0 contains the background pixels (as expected), channel 1 contains the cell-body pixels (as expected), but channel 2 contains cell-body + boundary pixels (see the second panel from the right in the figure below). If I subtract channel 1 from channel 2, I get a clean cell-boundary signal (right-most panel), but I am confused as to why the interior is filled.

Data

  • Images are grayscale, single-channel. The size of the tensor (without the batch dimension) is (1, 700, 1100). Intensities are normalized, and the dtype is torch.float32.
  • Masks are one-hot (3-channel, binary), with channel 0 the background class, channel 1 the cell bodies, and channel 2 the cell boundaries. The size of the tensor (without the batch dimension) is (3, 700, 1100). The dtype is torch.float32.

An ROI is applied during training and validation/prediction (see below).
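For reference, a one-hot mask in the layout described above can be built from an integer label map. A minimal NumPy sketch (the label values 0/1/2 are assumed to mean background/body/boundary, matching the bullet list; the tiny label map is made up):

```python
import numpy as np

# Hypothetical integer label map: 0 = background, 1 = cell body, 2 = cell boundary.
labels = np.array([[0, 0, 2],
                   [0, 1, 2],
                   [2, 2, 2]])

# Index into a 3x3 identity matrix to one-hot encode, then move the
# class axis to the front to get the (3, H, W) layout described above.
onehot = np.moveaxis(np.eye(3, dtype=np.float32)[labels], -1, 0)

print(onehot.shape)        # (3, 3, 3)
print(onehot.sum(axis=0))  # every pixel belongs to exactly one class
```

Note that by construction exactly one channel is 1 at every pixel, which is the property a softmax output (but not independent sigmoids) will reproduce.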

Hyperparameters

Training and predictions are run on a (256, 256) ROI.

Training

The architecture of the network is the same as in the example, with the exception of setting dimensions=2 and in_channels=1. As in the example, out_channels=3.

loss_function = DiceLoss(include_background=False, to_onehot_y=False, sigmoid=True, squared_pred=True)
optimizer = Adam(model.parameters(), 1e-3, weight_decay=1e-4, amsgrad=True)

In contrast to the 3D example, I added include_background=False to the loss function and increased the learning rate and decay by one order of magnitude.

Validation

I set include_background=False for the metric calculation. The rest is as in the example.

dice_metric = DiceMetric(include_background=False, reduction="mean")
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

Prediction

Same changes as in the validation.

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

Is there any obvious reason why prediction of channel 2 seems to accumulate the predictions of expected channels 1 and 2?
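One way to see what could cause this (a toy sketch with made-up logits, not code from the project): with sigmoid=True each output channel is an independent binary classifier, so nothing prevents channels 1 and 2 from both firing on the same pixel after thresholding; softmax makes the channels compete and sum to 1, so each pixel gets exactly one class:

```python
import numpy as np

# Made-up logits for a single pixel over 3 channels (background, body, boundary).
# Both channel 1 and channel 2 are strongly positive.
logits = np.array([-2.0, 3.0, 2.5])

# Sigmoid treats every channel independently: thresholding at 0.5
# can mark the same pixel in several channels at once.
sigmoid = 1.0 / (1.0 + np.exp(-logits))
print(sigmoid > 0.5)  # [False  True  True] -> pixel is "body" AND "boundary"

# Softmax normalizes across channels, so the classes compete:
softmax = np.exp(logits) / np.exp(logits).sum()
print(softmax.argmax())                 # 1 -> exactly one winning class
print(np.isclose(softmax.sum(), 1.0))   # True
```

This would explain why channel 2 can look like "body + boundary" under sigmoid: interior pixels can be confidently positive in both channels at once.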

This is a small inset of a prediction that shows the issue:

[Figure: prediction inset showing channels 0, 1, 2, and 2 − 1]

Legend

  • 0: prediction of background pixels;
  • 1: prediction of cell-body pixels;
  • 2: prediction of cell-boundary pixels;
  • 2 - 1: predicted channel 1 subtracted from predicted channel 2 (pretty close to ground truth, not shown).

Thanks a lot for any suggestions! Aaron

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
aarpon commented, Nov 27, 2020

A few results and observations:

  • Changing sigmoid=True to softmax=True in the loss function indeed did the trick! The cell body and cell boundary classes are now neatly separated.
  • To handle the large class imbalance, I tried switching to GeneralizedDiceLoss as suggested, but I get the following error (I tried both monai-weekly and the latest commit from the repository):
Traceback (most recent call last):
  File ".../Segmenter.py", line 298, in <module>
    loss = loss_function(outputs, labels)
  File ".../envs/deep_tools/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../envs/deep_tools/lib/python3.8/site-packages/monai/losses/dice.py", line 345, in forward
    f: torch.Tensor = 1.0 - (2.0 * (intersection * w).sum(1) + self.smooth_nr) / (
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Is there some preparation step that needs to be performed before the output of the forward pass is ready for the GeneralizedDiceLoss function?

Thanks!

0 reactions
wyli commented, Nov 28, 2020

That looks like an issue in the loss function; could you try batch=False in GeneralizedDiceLoss?
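To illustrate what the batch flag changes, here is a toy re-implementation of the generalized-Dice arithmetic in NumPy (a sketch of the idea, not MONAI's actual code): with batch reduction the batch axis is folded into the per-class statistics before the weighted sum over classes, while without it each sample keeps its own loss value, so the two modes leave a different number of dimensions for the final reduction:

```python
import numpy as np

def generalized_dice_loss(pred, gt, batch=False, eps=1e-8):
    """Toy generalized Dice loss, assuming pred/gt of shape (B, C, H, W)."""
    # Reduce over spatial dims only (batch=False) or batch + spatial dims (batch=True).
    axes = (0, 2, 3) if batch else (2, 3)
    w = 1.0 / (gt.sum(axis=axes) ** 2 + eps)   # class weights ~ 1 / volume^2
    inter = (pred * gt).sum(axis=axes)
    denom = (pred + gt).sum(axis=axes)
    # Weighted sum over the class axis (the last remaining axis).
    return 1.0 - 2.0 * (w * inter).sum(-1) / ((w * denom).sum(-1) + eps)

# Perfect prediction: loss should be ~0 in both modes.
gt = np.zeros((2, 3, 4, 4), dtype=np.float32)
gt[:, 0] = 1.0  # everything background, for simplicity
loss_per_sample = generalized_dice_loss(gt, gt, batch=False)  # shape (2,): one loss per sample
loss_whole_batch = generalized_dice_loss(gt, gt, batch=True)  # a single scalar
print(loss_per_sample.shape, float(loss_whole_batch))
```

The function name and shapes here are illustrative only; the point is that the choice of reduction axes determines how many dimensions survive into the final sum, which is consistent with the dimension-related IndexError in the traceback above.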

