Other datatype for LabelMap than float32
🚀 Feature

I noticed that a LabelMap and an IntensityImage are both saved as float32 tensors, which means that the LabelMap uses a lot more memory than needed. This is because of this piece of code in io.py, which casts all images to float32:
```python
def _read_sitk(path: TypePath) -> Tuple[torch.Tensor, np.ndarray]:
    if Path(path).is_dir():  # assume DICOM
        image = _read_dicom(path)
    else:
        image = sitk.ReadImage(str(path))
    data, affine = sitk_to_nib(image, keepdim=True)
    if data.dtype != np.float32:
        data = data.astype(np.float32)
    tensor = torch.from_numpy(data)
    return tensor, affine
```
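In practice this means that even a segmentation stored with an integer dtype on disk is loaded as float32. A minimal demonstration (the path is hypothetical, assuming a uint8 segmentation file):

```python
import torchio as tio

label = tio.LabelMap('subject_seg.nii.gz')  # hypothetical path to a uint8 segmentation
print(label.data.dtype)  # torch.float32, regardless of the dtype stored on disk
```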
Is there a reason for this `.astype(np.float32)`?
This could be made a lot more memory-friendly by removing the cast and storing segmentations in memory as, for example, uint8. I also expect spatial augmentations that require resampling to be a lot faster when they work with uint8 instead of float32.
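To put a number on the memory claim (the volume size is illustrative, not from the issue):

```python
import numpy as np

shape = (256, 256, 256)  # illustrative segmentation volume
voxels = np.prod(shape)
print(voxels * 4 / 2**20)  # float32: 64.0 MiB
print(voxels * 1 / 2**20)  # uint8:   16.0 MiB, i.e. a 4x saving per label volume
```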
Motivation
- Better use of memory
- Faster augmentations which require resampling
Pitch
Do not cast all tensors to float32; allow different dtypes.
Could these two lines be removed? All tests still pass when I comment them out. Maybe only cast bool to np.uint8 because SimpleITK does not support bool?
```python
if data.dtype != np.float32:
    data = data.astype(np.float32)
```
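A minimal sketch of the suggested alternative (an assumption, not a committed change): keep the dtype read from disk and cast only the bool case.

```python
# Hypothetical replacement (sketch): preserve the native dtype and only
# cast bool, which SimpleITK does not support as an image type.
if data.dtype == bool:
    data = data.astype(np.uint8)
tensor = torch.from_numpy(data)
```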
Top GitHub Comments
@romainVala I think the proposal is not really to force a specific type, but to stop forcing everything to be float32. So your partial volume maps (which maybe shouldn't be instantiated as a label map, as they don't contain categorical labels) would still be processed fine.
I just tried timing a resampling transform with uint8 and float32 inputs.
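The exact snippets are not shown here; below is a minimal reconstruction of the idea, in which scipy.ndimage, the array shape, and the transform are all assumptions rather than the original code:

```python
# Hypothetical reconstruction of the timing comparison (not the original snippet).
import timeit

import numpy as np
from scipy.ndimage import affine_transform

labels_u8 = np.random.randint(0, 4, size=(192, 192, 192), dtype=np.uint8)
labels_f32 = labels_u8.astype(np.float32)
matrix = np.diag([1.1, 1.1, 1.1])  # mild uniform zoom, forces resampling

for name, volume in [('uint8', labels_u8), ('float32', labels_f32)]:
    # order=0 (nearest neighbour) is the interpolation used for label maps
    seconds = timeit.timeit(lambda: affine_transform(volume, matrix, order=0), number=3)
    print(f'{name}: {seconds:.2f} s')
```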
So you’re right, it’s faster in uint8. I did this because some transforms required float, so I just converted everything to float. Another reason is that with a consistent data type, everything works smoothly with a data loader.
The tests probably pass because they typically use images that are created in float32 (and obviously because they’re not complete enough).
I agree that saving in float by default is not good. There should be at least a kwarg for the dtype.
So what do you think? I suppose there could be a `Cast` transform that could be used before a data loader (and by transforms that need float), but this would be quite backwards-incompatible. But if it makes the library way faster, it might be a good thing to do.
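For what it's worth, a rough sketch of what such a transform could look like, written as a plain callable over a sample dict rather than the library's actual Transform base class (every name here is hypothetical):

```python
import torch

class Cast:
    """Hypothetical sketch: cast selected tensors in a sample to a target dtype."""

    def __init__(self, dtype=torch.float32, keys=('image',)):
        self.dtype = dtype
        self.keys = keys  # which entries of the sample dict to cast

    def __call__(self, sample):
        for key in self.keys:
            sample[key] = sample[key].to(self.dtype)
        return sample
```

Used as, e.g., `Cast(torch.float32)` right before a data loader or a float-only transform, so labels could stay uint8 in memory until then.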