RescaleIntensity/Normalization IndexError: cannot do a non-empty take from an empty axes.
See original GitHub issue
Problem
I have modified the U-Net with PyTorch Lightning tutorial and I am getting `IndexError: cannot do a non-empty take from an empty axes.`
This happens when using RescaleIntensity with the LabelMap key as the mask. Without the mask, training works perfectly fine, but with it, I get this error.
Edit: I think this is caused by the mask having no positive values (the LabelMap is all zeros).
To reproduce
My preprocessing function:
def get_preprocessing_transform(self):
    """Build the deterministic preprocessing pipeline applied to every subject.

    The steps run in this order:
      1. Rescale image intensities to the range (-1, 1). The intensity
         statistics are restricted to the mask taken from the subject's
         ``'label'`` LabelMap.
      2. Crop or pad each volume to the spatial size (64, 512, 512).
      3. One-hot encode label maps into 4 classes (including background).

    Note: with TorchIO < 0.18.54, step 1 raises
    ``IndexError: cannot do a non-empty take from an empty axes`` when the
    ``'label'`` map is all zeros (empty mask). This was fixed in v0.18.54.
    """
    steps = [
        # 'label' is the key under which the LabelMap is stored in each Subject.
        tio.RescaleIntensity((-1, 1), masking_method='label'),
        tio.CropOrPad((64, 512, 512)),
        tio.OneHot(num_classes=4),
    ]
    return tio.Compose(steps)
The way I load data:
def load_entire_dataset(self, filenames, mode):
    """Build a ``tio.SubjectsDataset`` from image/segmentation file pairs.

    Args:
        filenames: Candidate file names, relative to ``self.root_dir``.
            Only names containing ``'img'`` are used as images. The matching
            segmentation file is the same name with ``'img'`` replaced by
            ``'seg'``.
        mode: ``'train'`` applies ``self.augment`` followed by
            ``self.preprocess``. Any other value applies only
            ``self.preprocess``.

    Returns:
        A ``tio.SubjectsDataset`` with one Subject per image name that is
        present in ``self.root_dir``. Image names not found on disk are
        skipped silently.
    """
    image_names = [name for name in filenames if 'img' in name]

    # List the directory once and test membership against a set, instead of
    # re-listing it (and scanning a list) for every file inside the loop.
    # Skip the listing entirely when there is nothing to look up.
    on_disk = set(os.listdir(self.root_dir)) if image_names else set()

    subjects = []
    for filename in image_names:
        # Sanity check: skip images that are not present on disk.
        if filename not in on_disk:
            continue
        subject = tio.Subject(
            # input (1, 64, 512, 512) (C, D, H, W)
            image=tio.ScalarImage(os.path.join(self.root_dir, filename)),
            # Segmentation map (4 classes including bg)
            label=tio.LabelMap(
                os.path.join(self.root_dir, filename.replace('img', 'seg'))
            ),
        )
        subjects.append(subject)

    if mode == 'train':
        # Augmentations from tio, then the shared preprocessing.
        transform = tio.Compose([self.augment, self.preprocess])
    else:
        transform = self.preprocess
    return tio.SubjectsDataset(subjects, transform)
I have taken most of the code from the tutorial and modified it. I get this error when the trainer runs the validation sanity check.
Expected behavior
For it to run without issues. It works OK when I loop over the dataset outside of the trainer.
Actual behavior No error
0/2 [00:00<?, ?it/s]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-81-467575a3c590> in <module>()
1 start = datetime.now()
2 print('Training started at', start)
----> 3 trainer.fit(model=model, datamodule=data)
4 # trainer.fit(model=model, train_dataloader=train_dataloader_fn(), val_dataloaders=val_dataloader_fn())
5 print('Training duration:', datetime.now() - start)
23 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
497
498 # dispath `start_training` or `start_testing` or `start_predicting`
--> 499 self.dispatch()
500
501 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
544
545 else:
--> 546 self.accelerator.start_training(self)
547
548 def train_or_test_or_predict(self):
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
71
72 def start_training(self, trainer):
---> 73 self.training_type_plugin.start_training(trainer)
74
75 def start_testing(self, trainer):
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
112 def start_training(self, trainer: 'Trainer') -> None:
113 # double dispatch to initiate the training loop
--> 114 self._results = trainer.run_train()
115
116 def start_testing(self, trainer: 'Trainer') -> None:
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
605 self.progress_bar_callback.disable()
606
--> 607 self.run_sanity_check(self.lightning_module)
608
609 # set stage for logging
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
862
863 # run eval step
--> 864 _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)
865
866 self.on_sanity_check_end()
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, max_batches, on_epoch)
711 dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
712
--> 713 for batch_idx, batch in enumerate(dataloader):
714 if batch is None:
715 continue
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
559 def _next_data(self):
560 index = self._next_index() # may raise StopIteration
--> 561 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
562 if self._pin_memory:
563 data = _utils.pin_memory.pin_memory(data)
/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.7/dist-packages/torchio/data/dataset.py in __getitem__(self, index)
83 # Apply transform (this is usually the bottleneck)
84 if self._transform is not None:
---> 85 subject = self._transform(subject)
86 return subject
87
/usr/local/lib/python3.7/dist-packages/torchio/transforms/transform.py in __call__(self, data)
124 subject = copy.copy(subject)
125 with np.errstate(all='raise', under='ignore'):
--> 126 transformed = self.apply_transform(subject)
127 if self.keep is not None:
128 for name, image in images_to_keep.items():
/usr/local/lib/python3.7/dist-packages/torchio/transforms/augmentation/composition.py in apply_transform(self, subject)
45 def apply_transform(self, subject: Subject) -> Subject:
46 for transform in self.transforms:
---> 47 subject = transform(subject)
48 return subject
49
/usr/local/lib/python3.7/dist-packages/torchio/transforms/transform.py in __call__(self, data)
124 subject = copy.copy(subject)
125 with np.errstate(all='raise', under='ignore'):
--> 126 transformed = self.apply_transform(subject)
127 if self.keep is not None:
128 for name, image in images_to_keep.items():
/usr/local/lib/python3.7/dist-packages/torchio/transforms/preprocessing/intensity/normalization_transform.py in apply_transform(self, subject)
51 image.data,
52 )
---> 53 self.apply_normalization(subject, image_name, mask)
54 return subject
55
/usr/local/lib/python3.7/dist-packages/torchio/transforms/preprocessing/intensity/rescale.py in apply_normalization(self, subject, image_name, mask)
66 ) -> None:
67 image = subject[image_name]
---> 68 image.set_data(self.rescale(image.data, mask, image_name))
69
70 def rescale(
/usr/local/lib/python3.7/dist-packages/torchio/transforms/preprocessing/intensity/rescale.py in rescale(self, tensor, mask, image_name)
78 mask = mask.numpy()
79 values = array[mask]
---> 80 cutoff = np.percentile(values, self.percentiles)
81 np.clip(array, *cutoff, out=array)
82 if self.in_min_max is None:
<__array_function__ internals> in percentile(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py in percentile(a, q, axis, out, overwrite_input, interpolation, keepdims)
3731 raise ValueError("Percentiles must be in the range [0, 100]")
3732 return _quantile_unchecked(
-> 3733 a, q, axis, out, overwrite_input, interpolation, keepdims)
3734
3735
/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py in _quantile_unchecked(a, q, axis, out, overwrite_input, interpolation, keepdims)
3851 r, k = _ureduce(a, func=_quantile_ureduce_func, q=q, axis=axis, out=out,
3852 overwrite_input=overwrite_input,
-> 3853 interpolation=interpolation)
3854 if keepdims:
3855 return r.reshape(q.shape + k)
/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py in _ureduce(a, func, **kwargs)
3427 keepdim = (1,) * a.ndim
3428
-> 3429 r = func(a, **kwargs)
3430 return r, keepdim
3431
/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py in _quantile_ureduce_func(a, q, axis, out, overwrite_input, interpolation, keepdims)
3965 n = np.isnan(ap[-1:, ...])
3966
-> 3967 x1 = take(ap, indices_below, axis=axis) * weights_below
3968 x2 = take(ap, indices_above, axis=axis) * weights_above
3969
<__array_function__ internals> in take(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in take(a, indices, axis, out, mode)
189 [5, 7]])
190 """
--> 191 return _wrapfunc(a, 'take', indices, axis=axis, out=out, mode=mode)
192
193
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
56
57 try:
---> 58 return bound(*args, **kwds)
59 except TypeError:
60 # A TypeError occurs if the object does have such a method in its
IndexError: cannot do a non-empty take from an empty axes.
System info
Output of `python <(curl -s https://raw.githubusercontent.com/fepegar/torchio/master/print_system.py)`:
(I am using Google Colab)
Platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
TorchIO: 0.18.53
PyTorch: 1.9.0+cu102
SimpleITK: 2.1.0 (ITK 5.2)
NumPy: 1.19.5
Python: 3.7.11 (default, Jul 3 2021, 18:01:19)
[GCC 7.5.0]
Issue Analytics
- State:
- Created 2 years ago
- Comments:5 (4 by maintainers)
Top GitHub Comments
Fixed in v0.18.54.
👍 Thanks again!
I think you might have misunderstood the usage of the `masking_method` argument. You seem to want to use a function to compute the statistics for the normalization, but in the first example you are returning a tensor computed from the global variable `mask_tensor`, and in the second example you are just returning an empty tensor. Maybe you mean `masking_method=lambda x: x == 1` for the first example. That being said, I agree this shouldn't result in an error, and the message should be more informative. I will commit a fix in a second.