Make kornia.augmentation.random_crop_generator safer
🚀 Feature
Make `kornia.augmentation.random_crop_generator` safer.
Motivation
Within `kornia.augmentation.random_crop_generator`, `_adapted_uniform` is used to generate random starting points for the crop. Since it returns a `torch.float`, the upper bound (`high`) is increased by one (`high = x_diff + 1`) and the result is cast to `torch.long`, which is basically a floor division. While this generates a uniform distribution over the integers, it opens up an edge case: what if `Uniform.rsample` (used by `_adapted_uniform`) returns exactly (or within floating-point precision of) `high`?
```python
from unittest import mock

import torch

from kornia.augmentation.random_generator import _adapted_uniform

batch_size = 1
input_size = (None, 2)
size = (None, 1)
same_on_batch = True

x_diff = input_size[1] - size[1]

with mock.patch(
    "kornia.augmentation.utils.Uniform.rsample",
    new=lambda self, shape: torch.ones(shape) * self.high,
):
    x_start = _adapted_uniform((batch_size,), 0, x_diff + 1, same_on_batch).long()

print(f"x_start {'<=' if x_start <= x_diff else '>'} x_diff")
```
`x_start` should be less than or equal to `x_diff`, but as you can see the script prints `x_start > x_diff`.
Pitch
Although this edge case is unlikely, I think we should handle it properly, especially since the fix is trivial:
```python
from typing import Tuple, Union


def _adapted_uniform_int(
    shape: Union[Tuple, torch.Size],
    low: Union[float, torch.Tensor],
    high: Union[float, torch.Tensor],
    same_on_batch: bool = False,
) -> torch.Tensor:
    # Shift the upper bound slightly below high + 1 so it cannot be drawn exactly,
    # then floor by casting to an integer dtype.
    return _adapted_uniform(shape, low, high + 1 - 1e-6, same_on_batch).int()
```
```python
with mock.patch(
    "kornia.augmentation.utils.Uniform.rsample",
    new=lambda self, shape: torch.ones(shape) * self.high,
):
    x_start = _adapted_uniform_int((batch_size,), 0, x_diff, same_on_batch).long()

print(f"x_start {'<=' if x_start <= x_diff else '>'} x_diff")
```
Now the output is `x_start <= x_diff`.
`_adapted_uniform_int` has the same signature as `_adapted_uniform`. The only difference to the call above is that we subtract a small constant from `high + 1` to prevent it from being drawn exactly. As long as this constant (`1e-6`) is larger than `torch.finfo(torch.float).eps ≈ 1.19e-7`, so that the subtraction is not swallowed by rounding, this should be fine.
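A quick numerical check (my own illustration, not part of the original issue) for a small crop range:

```python
import torch

# Illustrative check only: for a small range the 1e-6 offset survives float32
# rounding, so even the maximal possible draw still floors to x_diff rather
# than x_diff + 1.
x_diff = 1
upper = torch.tensor(x_diff + 1 - 1e-6, dtype=torch.float32)
print(upper.item() < x_diff + 1)  # True: the adjusted bound stays strictly below x_diff + 1
print(int(upper.item()))          # 1 == x_diff, i.e. still a valid crop start
```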
Additional Context
I’ve encountered this while working with `kornia.augmentation.random_crop_generator`. If this pattern is used in other places throughout the code base, the fix should be applicable there as well.
Top GitHub Comments
About the small numerical constant `1e-6` and the safety of `Uniform`: I don’t think it’s as easy as we thought. Consider the edge scenario where `rsample` indeed returns exactly `high`; since `.int()` is basically a floor division, the resulting index is then out of range. Without further investigation, I think this is due to numerical issues: any number that is drawn that is larger than `1.0 - torch.finfo(torch.float).eps / 2` will flip to `1.0`.
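As a rough illustration of that rounding behaviour (my own sketch, assuming float32 throughout; it is not the snippet from the original comment):

```python
import torch

# A float64 value strictly below 1.0, but closer to 1.0 than float32 can
# resolve: once stored as float32 it rounds up to exactly 1.0.
eps = torch.finfo(torch.float32).eps  # ~1.19e-7
almost_one = 1.0 - eps / 8            # strictly less than 1.0 in float64

sample = torch.tensor(almost_one, dtype=torch.float32)
print(sample.item() == 1.0)           # True: the draw flipped to exactly `high`
print(int(sample.item()))             # 1 instead of 0 after flooring
```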
The extra `dtype` parameter seems fine. I think it could look like this:
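The comment’s own snippet is not preserved above; as a hedged guess, a `dtype`-aware variant might look roughly like this (the `dtype` default and the final `.to(dtype)` cast are my assumptions, not the maintainer’s actual code):

```python
from typing import Tuple, Union

import torch

from kornia.augmentation.random_generator import _adapted_uniform


# Hypothetical sketch of a dtype-aware helper; the default dtype and the cast
# are assumptions for illustration, not kornia's actual API.
def _adapted_uniform_int(
    shape: Union[Tuple, torch.Size],
    low: Union[float, torch.Tensor],
    high: Union[float, torch.Tensor],
    same_on_batch: bool = False,
    dtype: torch.dtype = torch.long,
) -> torch.Tensor:
    # Shift the upper bound so high + 1 itself cannot be drawn, then floor by
    # casting to the requested integer dtype.
    return _adapted_uniform(shape, low, high + 1 - 1e-6, same_on_batch).to(dtype)
```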