Unet: Handle arbitrarily sized input images.
Currently the Unet model doesn't handle arbitrary input image sizes.
For example, here the spatial dimensions of the output don't match the input:
>>> import torch
>>> import segmentation_models_pytorch as smp
>>> model = smp.Unet()
>>> img = torch.rand(1, 3, 127, 127)
>>> model(img).shape
torch.Size([1, 1, 128, 128]) # Spatial dims don't match input
And this input raises an error:
>>> img = torch.rand(1, 3, 129, 129)
>>> model(img).shape
RuntimeError Traceback (most recent call last)
...
...
.../decoder.py in forward(self, x, skip)
36 x = F.interpolate(x, scale_factor=2, mode="nearest")
37 if skip is not None:
---> 38 x = torch.cat([x, skip], dim=1)
39 x = self.attention1(x)
40 x = self.conv1(x)
RuntimeError: torch.cat(): Sizes of tensors must match except in dimension 1. Got 10 and 9 in dimension 2 (The offending index is 1)
The Problem:
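When the spatial dims of the input to an encoder downsampling layer aren't divisible by 2, the halved size is rounded (up, judging by the examples above) and the layer's output is off by one pixel. When this feature map is later upsampled by a factor of 2 in the decoder using F.interpolate, the result no longer matches the skip connection: e.g. a 9-pixel feature map becomes 5 after downsampling, and upsampling gives 10 instead of 9.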
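To see where the "Got 10 and 9" in the traceback comes from, the size bookkeeping can be traced with plain integer arithmetic. This is a minimal sketch under one assumption: every encoder stage halves H and W with a stride-2 layer using kernel 3 and padding 1, which gives ceil(n / 2). Real encoders use other kernels in places, but the examples above round the same way. The helper names conv_out, encoder_sizes and first_mismatch are hypothetical, not part of segmentation_models_pytorch.

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a conv/pool along one spatial dim (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_sizes(n: int, depth: int = 5) -> list:
    # Assumption: each stage downsamples with kernel=3, stride=2, padding=1,
    # which works out to ceil(n / 2) for any n >= 1.
    sizes = [n]
    for _ in range(depth):
        sizes.append(conv_out(sizes[-1], kernel=3, stride=2, padding=1))
    return sizes


def first_mismatch(n: int, depth: int = 5):
    """Walk the decoder from the bottleneck: upsample by 2 (what
    F.interpolate(scale_factor=2) does) and compare with the size one level up.
    Returns (upsampled, expected) at the first mismatch, or None."""
    sizes = encoder_sizes(n, depth)
    x = sizes[-1]
    for expected in reversed(sizes[:-1]):
        x = 2 * x
        if x != expected:
            return x, expected
    return None


print(encoder_sizes(129))   # [129, 65, 33, 17, 9, 5]
print(first_mismatch(129))  # (10, 9)
print(first_mismatch(127))  # (128, 127)
print(first_mismatch(128))  # None
```

For 129 the first mismatch is at the deepest skip connection, which is the pair torch.cat complains about. For 127 the mismatch only appears at the last upsampling, which has no skip connection, so instead of an error the output is silently 128×128. Any size divisible by 2**5 = 32 passes through unchanged.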
The Fix:
The fix here is to replace F.interpolate with an upsampling layer that takes an additional (optional) argument, output_size. It could look as follows:
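import torch.nn as nn
import torch.nn.functional as F

class Upsample(nn.Module):
    def __init__(self, scale_factor=2, mode="nearest"):
        super().__init__()
        self.scale_factor, self.mode = scale_factor, mode
    def forward(self, input, output_size=None):
        upsampled = F.interpolate(input, scale_factor=self.scale_factor, mode=self.mode)
        if output_size is not None and upsampled.shape[-2:] != output_size[-2:]:
            # pad the bottom/right edges to reach output_size (negative amounts crop)
            dh = output_size[-2] - upsampled.shape[-2]
            dw = output_size[-1] - upsampled.shape[-1]
            upsampled = F.pad(upsampled, (0, dw, 0, dh))
        return upsampled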
This is, however, a breaking change for any users (probably close to zero) who rely on the current behavior (i.e., the output size not matching the input size).
I'm happy to create a pull request for this if it's agreed that this fix should be added.
Hi @jenkspt, thanks a lot for the suggestion; it may be a great feature! However, as far as I know, interpolation with scale_factor is friendlier for exporting to other formats. It has been a long time since I heard about these issues, so maybe ONNX and jit trace support dynamic shapes now? It would be nice to investigate this question before changing the current behavior.
This issue was closed because it has been stalled for 7 days with no activity.
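The maintainer's concern is keeping scale_factor interpolation inside the model. One workaround, not proposed in the thread and a different technique from the Upsample change, leaves the decoder untouched: pad the input up to a multiple of 32 (assumed here to be the total downsampling factor of the default five-stage encoder) and crop the output back to the original size. PadToMultiple is a hypothetical name, not part of segmentation_models_pytorch, and the pad/crop also depends on the input shape, so whether it exports any more cleanly is not checked here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PadToMultiple(nn.Module):
    """Hypothetical wrapper: zero-pad H and W up to a multiple of `divisor`,
    run the wrapped model, then crop the output back to the input size.

    Assumes the wrapped model returns a map with the same H and W as its
    (padded) input whenever H and W are divisible by `divisor`.
    """

    def __init__(self, model: nn.Module, divisor: int = 32):
        super().__init__()
        self.model = model
        self.divisor = divisor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, w = x.shape[-2:]
        pad_h = (-h) % self.divisor  # rows needed to reach the next multiple
        pad_w = (-w) % self.divisor  # columns needed to reach the next multiple
        # F.pad takes (left, right, top, bottom) for the last two dims
        x = F.pad(x, (0, pad_w, 0, pad_h))
        out = self.model(x)
        return out[..., :h, :w]  # drop the padded rows/columns


# Toy stand-in for a model with a total downsampling factor of 32.
toy = nn.Sequential(nn.AvgPool2d(32), nn.Upsample(scale_factor=32))
x = torch.rand(1, 3, 129, 129)
print(toy(x).shape)                 # torch.Size([1, 3, 128, 128])
print(PadToMultiple(toy)(x).shape)  # torch.Size([1, 3, 129, 129])
```

With the library, the wrapper would be applied as PadToMultiple(smp.Unet()). Zero padding is used for simplicity; mode="reflect" may give fewer border artifacts but requires the padding to be smaller than the input dimension.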