Size mismatch occurs in UNet model at 5th stage
I used the SMP library (segmentation_models_pytorch) to create a UNet model with the following configuration:
model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=30)
However, I have also tried other encoders (including the default resnet34), and the error appears for every encoder I choose. I am training on a custom dataset whose images have dimensions w=320, h=192.
My code runs fine until one of the final steps in the decoder block. The error traces back to smp/unet/decoder.py. When I run a training epoch, the error occurs in def forward(self, x, skip=None) of decoder.py:
def forward(self, x, skip=None):
    # upsample the decoder feature map by a factor of 2
    x = F.interpolate(x, scale_factor=2, mode="nearest")
    if skip is not None:
        # concatenate the encoder skip feature along the channel dimension
        x = torch.cat([x, skip], dim=1)
        x = self.attention1(x)
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.attention2(x)
    return x
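The failing line is the torch.cat above, where the skip connection from the encoder no longer has the same spatial size as the upsampled x. A minimal standalone snippet, using the step-3 shapes from the listing below, reproduces the same error:

import torch

x = torch.randn(1, 128, 56, 80)     # decoder feature after the 2x upsample
skip = torch.randn(1, 256, 55, 80)  # encoder skip feature, height off by one
torch.cat([x, skip], dim=1)
# RuntimeError: Sizes of tensors must match except in dimension 1.
# Expected size 56 but got size 55 for tensor number 1 in the list.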
For the first steps, everything runs fine and the dimensions of x match those of skip. Below is a list of the dimensions of both x and skip as I step through the decoder:
STEP 1: x.shape = torch.Size([1, 2048, 14, 20])   skip.shape = torch.Size([1, 1024, 14, 20])
STEP 2: x.shape = torch.Size([1, 256, 28, 40])    skip.shape = torch.Size([1, 512, 28, 40])
STEP 3: x.shape = torch.Size([1, 128, 56, 80])    skip.shape = torch.Size([1, 256, 55, 80])   <- 56 vs 55 mismatch
STEP 4: x.shape = torch.Size([1, 128, 56, 80])    skip.shape = torch.Size([1, 256, 55, 80])
STEP 5: x.shape = torch.Size([1, 3, 192, 320])    skip.shape = torch.Size([1, 256, 55, 80])
Around step 3, a mismatch between the spatial dimensions of the tensors starts occurring, which causes the error; the traceback is shown below. What I find strange is that I have used the exact same codebase with a different dataset that only had 6 classes, and in that case there was no issue. I am also unsure where this is happening, as I cannot find the root cause.
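One way to narrow this down is to check whether every batch really has the stated 192 x 320 size: with the default 5-stage encoder the input is downsampled by 2^5 = 32, so any height or width that is not a multiple of 32 produces exactly this kind of off-by-one skip shape (a height of 440, for instance, gives a stride-8 skip of 55 while the decoder upsamples to 56). A rough diagnostic sketch, assuming trainloader yields (image, mask) batches as in the traceback:

for images, masks in trainloader:
    h, w = images.shape[-2:]
    if h % 32 or w % 32:
        # any hit here will break the skip-connection concatenation
        print(f"batch with H x W = {h} x {w} is not divisible by 32")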
Traceback (most recent call last):
  File "/Users/fc/Desktop/ct/segmentation_code/main.py", line 141, in <module>
    trainer.train()
  File "/Users/fc/Desktop/ct/segmentation_code/ops/trainer.py", line 44, in train
    self.train_logs = self.train_epoch.run(self.trainloader)
  File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 47, in run
    loss, y_pred = self.batch_update(x, y)
  File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 87, in batch_update
    prediction = self.model.forward(x)
  File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/base/model.py", line 16, in forward
    decoder_output = self.decoder(*features)
  File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 119, in forward
    x = decoder_block(x, skip)
  File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 38, in forward
    x = torch.cat([x, skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 56 but got size 55 for tensor number 1 in the list.
Comments: 12 (7 by maintainers)
@JulienMaille yes, it is taken into consideration
not 32 but 2^depth; you can relax the constraint by using 3 or 4 stages instead of 5
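To make that suggestion concrete, a sketch of a 4-stage configuration (assuming the encoder_depth and decoder_channels arguments of the current smp.Unet signature; check your installed SMP version). The alternative is to pad or resize inputs so that H and W are multiples of 2^5 = 32.

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    encoder_depth=4,                      # 4 downsampling stages instead of the default 5
    decoder_channels=(256, 128, 64, 32),  # one entry per decoder stage; length must equal encoder_depth
    in_channels=3,
    classes=30,
)
# With depth 4 the input only needs to be divisible by 2^4 = 16 instead of 2^5 = 32.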