Multiclass segmentation

See original GitHub issue

First of all, thanks a lot for the great repo! I am trying to train a model on my data to segment multiple classes, and I have followed the code sample given:

# prepare data
preprocessing_fn = get_preprocessing('resnet34')
x = preprocessing_fn(x)

# prepare model
model = Unet(backbone_name='resnet34', encoder_weights='imagenet')
model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# train model
model.fit(x, y, epochs=100, batch_size=1)
model.save('trained_mask.h5')

My x is a list of arrays of shape (1, 256, 512, 3) (i.e. each one is a 256x512 image) and my y is the segmentation mask. So far I have trained successfully with a mask of shape (1, 256, 512, 1), but I would like to adapt this code for multi-class segmentation. Is that possible? If so, how?

I have tried two different options so far:

  • I defined the label mask as a one-hot vector, so that y now has shape (1, 256, 512, 13) (where 13 is the total number of classes to label).

    • This, however, fails with the error: ValueError: Error when checking target: expected sigmoid to have shape (None, None, 1) but got array with shape (256, 512, 13)
  • Given that error, I then tried (probably naively) to keep the shape (1, 256, 512, 1) but set each pixel of the mask to a number from 0 to 13, where 0 stands for an unlabelled pixel and any other number is the desired class label.

    • Although this trains successfully, the predicted mask at test time is still bounded to values between 0 and 1, so it does not allow multi-class segmentation.
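
The second attempt can be illustrated with a small NumPy sketch. It is not the library's code: the `sigmoid` helper and the random `logits` are hypothetical stand-ins for a one-channel sigmoid output layer. It only shows that such an output stays strictly between 0 and 1, so integer labels above 1 stored in a single channel can never be predicted.

```python
import numpy as np


def sigmoid(z):
    """Element-wise logistic function, as applied by a sigmoid output layer."""
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)

# Hypothetical raw outputs of a single-channel head for one 256x512 image.
logits = rng.normal(scale=3.0, size=(1, 256, 512, 1))
pred = sigmoid(logits)

# Integer label mask as in the second attempt: values 0..13 in one channel.
labels = rng.integers(0, 14, size=(1, 256, 512, 1))

print("prediction range:", pred.min(), pred.max())
print("label range:", labels.min(), labels.max())

# Fraction of pixels whose label lies outside what the sigmoid can express.
unreachable = np.mean(labels > 1)
print("fraction of pixels with an unreachable label:", unreachable)
```

Under these assumptions the shapes match, which is why training runs, but most label values fall outside the range the output layer can produce.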

Could anyone show me what I am doing wrong and how I could solve the multi-class problem?

Thank you!

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

6 reactions
qubvel commented, Nov 26, 2018

Hi, @freddifederica Yes, multiclass segmentation is possible. Define the model as follows:

# define model with output of N classes, where N > 1
model = Unet('resnet34', classes=N, activation='softmax')

# for multiclass segmentation choose another loss and metric
model.compile('Adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])

I also recommend trying other network architectures such as FPN and PSPNet (they are better suited to multiclass segmentation problems).
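
A model with N softmax output channels trained with categorical cross-entropy expects targets with one channel per class, so an integer label mask has to be one-hot encoded first. The following is a minimal NumPy sketch of that step and of turning predictions back into a label map. The helper names `to_one_hot` and `to_label_map` are hypothetical, and `num_classes = 14` is an assumption (labels 1 to 13 plus 0 for unlabelled pixels); adjust it to your own labelling.

```python
import numpy as np


def to_one_hot(label_mask, num_classes):
    """Turn an integer mask of shape (B, H, W) or (B, H, W, 1)
    into a float32 one-hot array of shape (B, H, W, num_classes)."""
    label_mask = np.asarray(label_mask)
    if label_mask.ndim == 4 and label_mask.shape[-1] == 1:
        label_mask = label_mask[..., 0]  # drop the singleton channel axis
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float32)[label_mask]


def to_label_map(probabilities):
    """Pick the most probable class per pixel: (B, H, W, C) -> (B, H, W)."""
    return np.argmax(probabilities, axis=-1)


num_classes = 14  # assumption: 13 labelled classes plus 0 for "unlabelled"

# Hypothetical integer mask for one 256x512 image.
mask = np.random.default_rng(0).integers(0, num_classes, size=(1, 256, 512, 1))

y = to_one_hot(mask, num_classes)
print(y.shape)  # (1, 256, 512, 14)

# Applied to a softmax prediction, to_label_map returns class ids per pixel;
# applied to the one-hot target it simply recovers the original mask.
recovered = to_label_map(y)
```

The array `y` has the shape such a model would be fitted against, and `to_label_map` can be applied to the output of `model.predict` to obtain one class id per pixel.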

1 reaction
qubvel commented, Nov 27, 2018

Just add this line before loading the model: import segmentation_models

Top Results From Across the Web

Multi-Class Semantic Segmentation with U-Net & PyTorch
Semantic segmentation is a computer vision task in which every pixel of a given image frame is classified/labelled based on whichever class ...

208 - Multiclass semantic segmentation using U-Net - YouTube
Code generated in the video can be downloaded from here: https://github.com/bnsreenu/python_for_microscopists The dataset used in this video ...

Multiclass semantic segmentation using DeepLabV3+ - Keras
Semantic segmentation, with the goal to assign semantic labels to every pixel in an image, is an essential computer vision task. In this...

A Machine Learning Engineer's Tutorial to Transfer Learning ...
In this hands-on tutorial we will review how to start from a binary semantic segmentation task and transfer the learning to suit multi-class...

U-net-for-Multi-class-semantic-segmentation - GitHub
This example demonstrates the use of U-net model for pathology segmentation on retinal images. This supports binary and multi-class segmentation.
