`RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <enum 'DefaultDataKeys'>.`
🐛 Bug
It seems there is some sort of state issue with custom preprocessors and Flash modules. My custom preprocessor code is below; it is based on the examples in the documentation and is very simple. I can't reproduce the error consistently, but it has happened at least once in my past few training iterations. `to_onnx` seems to trigger it as well, though not always, which makes me think it has something to do with the preprocessor state (`current_transform`, perhaps).
I am happy to help debug this issue, but it's really annoying and I do need a custom preprocessor for my data.
To Reproduce
Steps to reproduce the behavior:
- Add a custom preprocessor to a data module initialized from a data source
- Add an example_input_array to the module
- Finetune model
Stack trace:
Traceback (most recent call last):
File "xxx\dummy_run.py", line 35, in <module>
run()
File "xxx\dummy_run.py", line 31, in run
train(_cfg_)
File "xxx\train.py", line 78, in train
test_model(trainer, had_error, checkpointer, data_module)
File "xxx\test_model.py", line 18, in test_model
raise had_error
File "xxx\train.py", line 69, in train
trainer.finetune(net, datamodule=data_module,
File "xxx\env\lib\site-packages\flash\core\trainer.py", line 165, in finetune
return super().fit(model, train_dataloader, val_dataloaders, datamodule)
File "xxxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 553, in fit
self._run(model)
File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 918, in _run
self._dispatch()
File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 986, in _dispatch
self.accelerator.start_training(self)
File "xxx\overview.ai\tyson\env\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 92, in start_training
self.training_type_plugin.start_training(trainer)
File "xxx\env\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 161, in start_training
self._results = trainer.run_stage()
File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 996, in run_stage
return self._run_train()
File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1026, in _run_train
self._pre_training_routine()
File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1019, in _pre_training_routine
ref_model.summarize(max_depth=max_depth)
File "xxx\env\lib\site-packages\pytorch_lightning\core\lightning.py", line 1711, in summarize
model_summary = ModelSummary(self, max_depth=max_depth)
File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 215, in __init__
self._layer_summary = self.summarize()
File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 271, in summarize
self._forward_example_input()
File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 288, in _forward_example_input
input_ = model._apply_batch_transfer_handler(input_)
File "xxx\env\lib\site-packages\pytorch_lightning\core\lightning.py", line 281, in _apply_batch_transfer_handler
batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
File "xxx\env\lib\site-packages\flash\core\data\data_pipeline.py", line 609, in __call__
outputs = additional_func(outputs)
File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "xxx\env\lib\site-packages\flash\core\data\batch.py", line 239, in forward
samples = self.per_batch_transform(samples)
File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "xxx\env\lib\site-packages\flash\core\data\utils.py", line 178, in forward
return self.func(*args, **kwargs)
File "xxx\lib\site-packages\flash\core\data\process.py", line 409, in per_batch_transform_on_device
return self.current_transform(batch)
File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "xxx\lib\site-packages\flash\core\data\transforms.py", line 40, in forward
keys = list(filter(lambda key: key in x, self.keys))
File "xxx\lib\site-packages\flash\core\data\transforms.py", line 40, in <lambda>
keys = list(filter(lambda key: key in x, self.keys))
File "xxx\env\lib\site-packages\torch\_tensor.py", line 670, in __contains__
raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <enum 'DefaultDataKeys'>.
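For context, the failing line in `flash/core/data/transforms.py` filters keys with `key in x`. That membership test is fine when `x` is a sample dict, but once the bare `example_input_array` tensor reaches the transform, it dispatches to `Tensor.__contains__`, which rejects enum keys. A minimal sketch of the failure:

import torch
from flash.core.data.data_source import DefaultDataKeys

x = torch.ones(1, 3, 224, 224)  # shaped like an example_input_array

# ApplyToKeys.forward effectively runs:
#     keys = list(filter(lambda key: key in x, self.keys))
# With a dict sample this works; with a bare tensor it raises:
DefaultDataKeys.INPUT in x
# RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you
# passed in a <enum 'DefaultDataKeys'>.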
Code sample
from torchvision import transforms as T
from torchvision.transforms.functional import pil_to_tensor
from PIL import Image

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import DefaultPreprocess
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image.classification.transforms import default_transforms

from lightning.CustomDataSource import CustomDataSource


class OverviewPreprocessor(DefaultPreprocess):
    def __init__(self, config):
        self.config = config
        img_size = (config['image_size'], config['image_size'])

        # Resize only the INPUT entry of the sample dict.
        resize = ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Compose([T.Resize(config['image_size'])]),
        )

        train_transform = merge_transforms(
            default_transforms(img_size),
            {
                "pre_tensor_transform": resize,
                "post_tensor_transform": ApplyToKeys(
                    DefaultDataKeys.INPUT,
                    # T.Compose expects a list of transforms.
                    T.Compose([T.RandomHorizontalFlip()]),
                ),
            },
        )
        eval_transform = merge_transforms(
            default_transforms(img_size),
            {"post_tensor_transform": resize},
        )

        super().__init__(
            train_transform=train_transform,
            val_transform=eval_transform,
            test_transform=eval_transform,
            data_sources={"regression": CustomDataSource()},
            default_data_source="xxx",
        )

    @staticmethod
    def input_to_tensor(in_pil: Image.Image):
        """Transform which creates a tensor from the given PIL image and converts it to ``float``."""
        return pil_to_tensor(in_pil).float()
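For completeness, a hedged sketch of how this preprocessor would be wired up per the repro steps, assuming Flash 0.5's `DataModule.from_data_source` API (`net`, `train_files`, and the config values are placeholders, not from the original report):

import flash
from flash.core.data.data_module import DataModule

config = {"image_size": 224}
datamodule = DataModule.from_data_source(
    "regression",                     # registered in OverviewPreprocessor
    train_data=train_files,           # placeholder: whatever CustomDataSource loads
    preprocess=OverviewPreprocessor(config),
    batch_size=4,
)

trainer = flash.Trainer(max_epochs=1, gpus=1)
trainer.finetune(net, datamodule=datamodule)  # net defines example_input_array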
Expected behavior
Expect `to_onnx` and `model.summarize` to use the correct data types for inference.
Environment
- PyTorch Version (e.g., 1.0): 1.9
- OS (e.g., Linux): Windows & Linux
- How you installed PyTorch (conda, pip, source): conda
- Python version: 3.9
- CUDA/cuDNN version: 11.1
- GPU models and configuration: 3090 and A100
- Lightning Flash version: 0.5
- Lightning Version: 1.4.7
- Any other relevant information: This only started happening once I finetuned or called `to_onnx` on a model with a custom preprocessor. It could be something else, but my gut says it's probably the preprocess code.
Additional context
FWIW, the Task that I'm using is very basic but initializes the `example_input_array` in the constructor like this:
self.example_input_array = torch.ones(
    (1, channels, config['image_size'], config['image_size'])
)
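Tracing the stack above, `ModelSummary` pushes this bare tensor through the batch-transfer hooks, which is where Flash's per-batch transform receives a tensor instead of the expected sample dict. Roughly (paraphrasing `_forward_example_input` from the trace, not verbatim Lightning internals):

input_ = model.example_input_array                    # a bare torch.Tensor
input_ = model._apply_batch_transfer_handler(input_)  # -> Flash data pipeline
# -> per_batch_transform_on_device -> ApplyToKeys -> `key in tensor` -> crash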
BTW, this is still an issue. For anyone else who runs into this, a hotfix is to override `_apply_batch_transfer_handler` in your Task. But I'm not comfortable making this a pull request, as `_`-prefixed functions are generally not meant to be overridden per convention, and I'm not exactly sure what other side effects this would have.