Multiclass segmentation
Firstly, thanks a lot for the great repo! I am trying to train on my data to segment multiple classes, and I have followed the code sample given:

```python
from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing

# prepare data (x: input images, y: segmentation masks)
preprocessing_fn = get_preprocessing('resnet34')
x = preprocessing_fn(x)

# prepare model
model = Unet(backbone_name='resnet34', encoder_weights='imagenet')
model.compile('Adam', 'binary_crossentropy', ['binary_accuracy'])

# train model
model.fit(x, y, epochs=100, batch_size=1)
model.save('trained_mask.h5')
```
My `x` has shape (1, 256, 512, 3), i.e. a single 256x512 image with 3 channels, and my `y` is the corresponding segmentation mask. Training works when the mask has shape (1, 256, 512, 1); however, I would like to adapt this code for multi-class segmentation. Is that possible? If so, how?
I have tried two different options so far:
- Defining the label mask as one-hot vectors, so that `y` now has shape (1, 256, 512, 13), where 13 is the total number of classes to label.
  - This, however, fails with: `ValueError: Error when checking target: expected sigmoid to have shape (None, None, 1) but got array with shape (256, 512, 13)`
- Given that error, I (probably naively) kept the shape (1, 256, 512, 1) but set each pixel of the mask to a number from 0 to 13, where 0 stands for an unlabelled pixel and any other number is the desired class label.
  - Although this trains successfully, the predicted mask is still bounded between 0 and 1, so it does not give a multi-class segmentation.
Could anyone show me what I am doing wrong and how I could solve the multi-class problem?
Thank you!
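A one-hot target of the shape described in the first option can be built from an integer label mask like the one in the second option. The following is a minimal NumPy sketch, assuming the labels are integers in `[0, num_classes)`; the helper name `to_one_hot` is hypothetical. Note that if 0 marks unlabelled pixels and 1 to 13 are classes, that is 14 distinct values, so 14 channels would be needed.

```python
import numpy as np

def to_one_hot(mask, num_classes):
    """Convert an integer label mask of shape (N, H, W, 1) or (N, H, W)
    into a one-hot float array of shape (N, H, W, num_classes)."""
    labels = np.asarray(mask)
    if labels.ndim == 4 and labels.shape[-1] == 1:
        labels = labels[..., 0]  # drop the trailing singleton channel
    labels = labels.astype(np.int64)
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("labels must lie in [0, num_classes)")
    # Indexing rows of the identity matrix maps label k to unit vector e_k
    return np.eye(num_classes, dtype=np.float32)[labels]

# Example: a mask with values 0..13 (0 = unlabelled) needs 14 channels
mask = np.random.randint(0, 14, size=(1, 256, 512, 1))
y = to_one_hot(mask, num_classes=14)
print(y.shape)  # (1, 256, 512, 14)
```

Each pixel then carries exactly one 1 across the channel axis, which is the target format a per-pixel softmax output with a categorical cross-entropy loss expects.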
Issue created 5 years ago; 6 comments (3 by maintainers).
Hi @freddifederica, yes, multiclass segmentation is possible. Define the model as follows:
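One common way to do this with `segmentation_models` (a sketch, not necessarily the maintainer's exact code) is to pass the number of output channels and a softmax activation to the constructor, e.g. `Unet(backbone_name='resnet34', encoder_weights='imagenet', classes=14, activation='softmax')`, and compile with `categorical_crossentropy` against one-hot targets. The plain NumPy sketch below, with hypothetical helper names and made-up logits, shows what the softmax head changes: each pixel gets a probability distribution over the classes, and the label map is recovered with `argmax` over the channel axis instead of thresholding a single 0 to 1 output.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def decode_prediction(probs):
    """Turn per-pixel class probabilities (N, H, W, C) into a label map (N, H, W)."""
    return probs.argmax(axis=-1)

# Hypothetical raw network outputs for one 2x2 image and 3 classes
logits = np.array([[[[2.0, 0.5, 0.1], [0.1, 3.0, 0.2]],
                    [[0.3, 0.2, 1.5], [1.0, 1.0, 4.0]]]])
probs = softmax(logits)            # each pixel's probabilities sum to 1
labels = decode_prediction(probs)  # integer class index per pixel
print(labels[0].tolist())  # [[0, 1], [2, 2]]
```

With a Keras model built using `activation='softmax'`, `model.predict` already returns these per-pixel probabilities, so in practice only the `argmax` step would be needed.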
I also recommend trying other network architectures such as FPN and PSPNet (they are more suitable for multiclass segmentation problems).
Just add this import before loading the model:

```python
import segmentation_models
```