Proper mask shape for a multiclass task

See original GitHub issue

Hello,

I’m trying to use custom_unet for a multiclass task. In my case, all mask inputs have the following shape: (BATCH_SIZE, NUMBER_OF_CLASSES, IMAGE_WIDTH, IMAGE_HEIGHT, 1). I have 4 different colours on a black background, my images are 256x256 pixels, and I use a batch size of 4, so I end up with the shape (4, 4, 256, 256, 1). Unfortunately, custom_unet doesn’t accept this shape and gives me this error:

<...>
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\keras_unet\losses.py", line 44, in jaccard_distance
    intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 899, in binary_op_wrapper
    return func(x, y, name=name)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\ops\math_ops.py", line 1206, in _mul_dispatch
    return gen_math_ops.mul(x, y, name=name)
  File "C:\Users\E-soft\Anaconda3\envs\Explorium\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py", line 6698, in mul
    _six.raise_from(_core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [4,4,256,256] vs. [4,256,256,4] [Op:Mul] name: loss/conv2d_18_loss/mul/

It seems that I should reshape my mask inputs. Should the shape be (BATCH_SIZE, IMAGE_WIDTH, IMAGE_HEIGHT, NUMBER_OF_CLASSES, 1)?
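The mismatch in the traceback can be reproduced without the model. The multiply in jaccard_distance needs broadcast-compatible shapes, and a channels-first mask (4, 4, 256, 256) cannot be broadcast against a channels-last prediction (4, 256, 256, 4). A minimal sketch, using NumPy broadcasting as a stand-in for the TensorFlow op (assumes NumPy 1.20+ for np.broadcast_shapes):

```python
import numpy as np

# Shapes taken from the error message: the mask is channels-first,
# while the model's prediction is channels-last.
y_true = np.zeros((4, 4, 256, 256), dtype=np.float32)
y_pred = np.zeros((4, 256, 256, 4), dtype=np.float32)

def shapes_broadcast(a: np.ndarray, b: np.ndarray) -> bool:
    """Return True if a and b can be combined element-wise."""
    try:
        np.broadcast_shapes(a.shape, b.shape)
        return True
    except ValueError:
        return False

# Fails: aligning from the right, the last axes are 256 vs 4.
print(shapes_broadcast(y_true, y_pred))  # False

# Moving the class axis to the end makes the shapes identical.
y_true_last = np.transpose(y_true, (0, 2, 3, 1))
print(shapes_broadcast(y_true_last, y_pred))  # True
```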

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 8 (4 by maintainers)

Top GitHub Comments

2 reactions
karolzak commented, Jun 29, 2020

Hi @EtagiBI !

For multiclass semantic segmentation tasks, you need to reshape your masks into (BATCH_SIZE, IMAGE_WIDTH, IMAGE_HEIGHT, NUMBER_OF_CLASSES), with no trailing 1.
To explain the reasoning: with this shape, NUMBER_OF_CLASSES is the number of layers in your output mask, and each layer represents one class you are trying to predict. In other words, your output mask is a volume of binary masks, each holding the predictions for a specific class.
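A minimal NumPy sketch of that conversion, assuming the masks are stored as a (BATCH_SIZE, NUMBER_OF_CLASSES, IMAGE_WIDTH, IMAGE_HEIGHT, 1) array as described in the question. It drops the trailing singleton axis and moves the class axis to the end. Note that this is an axis move, not a plain np.reshape: reshaping to the target shape would interleave pixels from different classes. The helper name to_channels_last is hypothetical.

```python
import numpy as np

def to_channels_last(masks: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (batch, classes, W, H, 1) -> (batch, W, H, classes).

    Uses np.moveaxis rather than np.reshape, so each class layer keeps
    its own pixels instead of being mixed with other classes.
    """
    if masks.ndim != 5 or masks.shape[-1] != 1:
        raise ValueError(f"expected (B, C, W, H, 1), got {masks.shape}")
    squeezed = masks[..., 0]               # (B, C, W, H)
    return np.moveaxis(squeezed, 1, -1)    # (B, W, H, C)

# Dummy batch matching the question: 4 images, 4 classes, 256x256 pixels.
rng = np.random.default_rng(0)
masks = rng.integers(0, 2, size=(4, 4, 256, 256, 1)).astype(np.float32)

y = to_channels_last(masks)
print(y.shape)  # (4, 256, 256, 4)
```

Layer k of the result (y[..., k]) is exactly class k of the original array (masks[:, k, :, :, 0]), which is the "one binary mask per class" layout described above.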

To give you some visuals, here is a dummy example:

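import numpy as np
from keras_unet import utils

# x: input images, y: ground-truth masks, y_pred: model predictions
print(x.shape, y.shape, y_pred.shape)
# (1, 256, 256, 3) (1, 256, 256, 3) (1, 256, 256, 3)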

utils.plot_imgs(
    org_imgs=np.vstack(
        (x, x, x)),
    mask_imgs=np.stack(
        (y[0][:,:,0],  # first layer of mask is the first class gt
         y[0][:,:,1],   # second layer of mask is the second class gt
         y[0][:,:,2])),   # third layer of mask is the third class gt
    pred_imgs=np.stack(
        (y_pred[0][:,:,0],  # first layer of mask is the first class predictions
         y_pred[0][:,:,1],   # second layer of mask is the second class predictions
         y_pred[0][:,:,2]))   # third layer of mask is the third class predictions
)

[Image: output of the utils.plot_imgs call above]

1 reaction
karolzak commented, Jul 2, 2020

@EtagiBI the code you shared looks fine, although I prefer array slicing to extract specific layers. Also, the GRAY2RGB conversion is unnecessary unless you really need it: a grayscale image is perfectly fine, especially for a segmentation mask.
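One way to build the per-class volume straight from a single-channel grayscale label mask, with no RGB conversion, is to compare the mask against each class value and then pull individual layers back out with array slicing. A minimal sketch; CLASS_VALUES and grayscale_to_layers are hypothetical, so substitute the pixel intensities your masks actually use:

```python
import numpy as np

# Hypothetical grayscale intensities, one per class; 0 is the black background.
CLASS_VALUES = (50, 100, 150, 200)

def grayscale_to_layers(mask: np.ndarray, class_values=CLASS_VALUES) -> np.ndarray:
    """Turn an (H, W) grayscale mask into an (H, W, num_classes) stack of binary masks."""
    values = np.asarray(class_values).reshape(1, 1, -1)       # (1, 1, C)
    return (mask[..., np.newaxis] == values).astype(np.float32)

# Tiny 2x3 mask: row 0 = background, class 0, class 3; row 1 = class 1, class 2, background.
mask = np.array([[0, 50, 200],
                 [100, 150, 0]], dtype=np.uint8)

layers = grayscale_to_layers(mask)
print(layers.shape)  # (2, 3, 4)

# Array slicing extracts the binary mask of a single class (here class 3).
print(layers[:, :, 3])
# [[0. 0. 1.]
#  [0. 0. 0.]]
```

Background pixels end up as zeros in every layer, so no extra background channel is created by this sketch.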

If your prediction results look messy, your model has probably failed to learn anything.
In my dummy example (training on a single image…), I used the code below:

from tensorflow.keras.optimizers import Adam
from keras_unet.models import custom_unet
from keras_unet.metrics import iou

# 3 output channels: one binary mask per class, each with its own sigmoid
model = custom_unet(
    (256, 256, 3), num_classes=3,
    use_batch_norm=False,
    dropout=0.0, dropout_type='standard',
    output_activation='sigmoid')

model.compile(
    optimizer=Adam(),
    loss='binary_crossentropy',
    metrics=[iou]
)

# x: (1, 256, 256, 3) image, y: (1, 256, 256, 3) mask with one layer per class
model.fit(x, y, epochs=200)

When I tried building the network with a softmax output activation, it couldn’t learn anything.
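The sketch below illustrates one mathematical difference between the two output activations on a single pixel's raw scores (the logits are made-up numbers). It is not a diagnosis of why the softmax run failed. A sigmoid squashes each channel independently, so every layer behaves like its own binary mask, which is what binary_crossentropy expects. A softmax couples the channels so they sum to 1 per pixel; it is normally paired with categorical_crossentropy and one-hot masks where every pixel belongs to exactly one class. One consequence: a black-background pixel whose target is all zeros cannot be matched by a softmax output.

```python
import numpy as np

def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(z - np.max(z, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

# Made-up raw scores for one pixel across 3 class channels.
logits = np.array([2.0, -1.0, 0.5])

p_sigmoid = sigmoid(logits)   # independent per-channel probabilities
p_softmax = softmax(logits)   # mutually exclusive class probabilities

print(p_sigmoid.sum())  # greater than 1: channels are not coupled
print(p_softmax.sum())  # close to 1.0: channels compete for the same pixel

# A background pixel with no class: the ideal output is all zeros.
very_negative = np.full(3, -10.0)
print(sigmoid(very_negative).max())  # close to 0: sigmoid can approach an all-zero target
print(softmax(very_negative).max())  # at least 1/3: softmax outputs must sum to 1
```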
