`RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <enum 'DefaultDataKeys'>.`


🐛 Bug

It seems that there is some sort of state issue with custom preprocessors and Flash modules. The basic code for my custom preprocessor is below; it is based on the examples in the documentation and is very simple. I can't get the error to appear consistently, but it has happened at least once in my past few training iterations. to_onnx seems to trigger it as well, but not always, which makes me think it has something to do with the preprocessor state (current_transform, perhaps).

I am happy to help debug this issue, but it's really annoying and I do need a custom preprocessor for my data.

To Reproduce

Steps to reproduce the behavior:

  1. Add a custom preprocessor to a data module initialized from a data source
  2. Add an example_input_array to the module
  3. Finetune the model (a rough sketch of these steps follows below)
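
Roughly, the setup looks like the sketch below. The data-module entry point, the "regression" data-source name, and the Task class are placeholders based on the Flash 0.5 docs and my own code, so treat this as an outline of the three steps rather than an exact repro:

import flash
import torch
from flash.core.data.data_module import DataModule

# 1. Data module built from a data source, using the custom preprocessor shown in
#    "Code sample" below (config / train_files / val_files are placeholders).
preprocess = OverviewPreprocessor(config)
data_module = DataModule.from_data_source(
    "regression",           # name registered in the preprocessor's data_sources
    train_data=train_files,
    val_data=val_files,
    preprocess=preprocess,
)

# 2. Task that sets example_input_array in its constructor (see "Additional context").
net = MyRegressionTask(config)  # placeholder Task class

# 3. Finetuning calls model.summarize(), which pushes example_input_array through the
#    Flash batch-transfer pipeline and raises the RuntimeError below.
trainer = flash.Trainer(max_epochs=1, gpus=1)
trainer.finetune(net, datamodule=data_module, strategy="freeze")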

Stack trace:

Traceback (most recent call last):
  File "xxx\dummy_run.py", line 35, in <module>
    run()
  File "xxx\dummy_run.py", line 31, in run
    train(_cfg_)
  File "xxx\train.py", line 78, in train
    test_model(trainer, had_error, checkpointer, data_module)
  File "xxx\test_model.py", line 18, in test_model
    raise had_error
  File "xxx\train.py", line 69, in train
    trainer.finetune(net, datamodule=data_module,
  File "xxx\env\lib\site-packages\flash\core\trainer.py", line 165, in finetune
    return super().fit(model, train_dataloader, val_dataloaders, datamodule)
  File "xxxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 553, in fit
    self._run(model)
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 918, in _run
    self._dispatch()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "xxx\overview.ai\tyson\env\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "xxx\env\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 996, in run_stage
    return self._run_train()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1026, in _run_train
    self._pre_training_routine()
  File "xxx\env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1019, in _pre_training_routine
    ref_model.summarize(max_depth=max_depth)
  File "xxx\env\lib\site-packages\pytorch_lightning\core\lightning.py", line 1711, in summarize
    model_summary = ModelSummary(self, max_depth=max_depth)
  File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 215, in __init__
    self._layer_summary = self.summarize()
  File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 271, in summarize
    self._forward_example_input()
  File "xxx\env\lib\site-packages\pytorch_lightning\core\memory.py", line 288, in _forward_example_input
    input_ = model._apply_batch_transfer_handler(input_)
  File "xxx\env\lib\site-packages\pytorch_lightning\core\lightning.py", line 281, in _apply_batch_transfer_handler
    batch = self.transfer_batch_to_device(batch, device, dataloader_idx)
  File "xxx\env\lib\site-packages\flash\core\data\data_pipeline.py", line 609, in __call__
    outputs = additional_func(outputs)
  File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "xxx\env\lib\site-packages\flash\core\data\batch.py", line 239, in forward
    samples = self.per_batch_transform(samples)
  File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "xxx\env\lib\site-packages\flash\core\data\utils.py", line 178, in forward
    return self.func(*args, **kwargs)
  File "xxx\lib\site-packages\flash\core\data\process.py", line 409, in per_batch_transform_on_device
    return self.current_transform(batch)
  File "xxx\env\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "xxx\lib\site-packages\flash\core\data\transforms.py", line 40, in forward
    keys = list(filter(lambda key: key in x, self.keys))
  File "xxx\lib\site-packages\flash\core\data\transforms.py", line 40, in <lambda>
    keys = list(filter(lambda key: key in x, self.keys))
  File "xxx\env\lib\site-packages\torch\_tensor.py", line 670, in __contains__
    raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <enum 'DefaultDataKeys'>.
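
The last two frames are the interesting ones: ApplyToKeys filters its keys with `key in x`, which assumes the batch `x` is a dict. When the batch arrives as a bare tensor (as it does for example_input_array), Python dispatches to Tensor.__contains__, which only accepts tensors or scalars. The error reproduces in isolation, outside of Flash:

import torch
from flash.core.data.data_source import DefaultDataKeys

x = torch.ones(1, 3, 224, 224)       # a bare image batch, not a {DefaultDataKeys: value} dict
try:
    DefaultDataKeys.INPUT in x       # what ApplyToKeys.forward effectively evaluates per key
except RuntimeError as err:
    print(err)                       # Tensor.__contains__ only supports Tensor or scalar, ...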

Code sample

from lightning.CustomDataSource import CustomDataSource
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import DefaultPreprocess
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image.classification.transforms import default_transforms
from torchvision import transforms as T
from torchvision.transforms.functional import pil_to_tensor
import numpy as np

class OverviewPreprocessor(DefaultPreprocess):
    def __init__(self, config):
        self.config = config
        img_size = (config['image_size'], config['image_size'])

        thing = ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Compose([
                T.Resize(config['image_size'])
            ])
        )

        train_transform = merge_transforms(
            default_transforms(img_size),
            {
                "pre_tensor_transform": thing,
                "post_tensor_transform": ApplyToKeys(
                    DefaultDataKeys.INPUT,
                    T.Compose([
                        T.RandomHorizontalFlip(),
                    ])
                )
            }
        )

        tform = merge_transforms(
            default_transforms(img_size),
            {"post_tensor_transform": thing}
        )

        super().__init__(
            train_transform=train_transform,
            val_transform=tform,
            test_transform=tform,
            data_sources={
                "regression": CustomDataSource()
            },
            default_data_source="xxx",
        )

    @staticmethod
    def input_to_tensor(in_pil: np.ndarray):
        """Transform which creates a tensor from the given pil image and converts it to ``float``"""
        return pil_to_tensor(in_pil).float()

Expected behavior

to_onnx and model.summarize are expected to use the correct data types for inference instead of raising.
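
For reference, both of the calls below end up forwarding self.example_input_array through the same batch-transfer path (the ONNX file name and the `net` variable are just placeholders):

# 'net' is the Task instance from the repro steps above.
net.to_onnx("model.onnx")             # exports using self.example_input_array
summary = net.summarize(max_depth=1)  # same code path via ModelSummary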

Environment

  • PyTorch Version (e.g., 1.0): 1.9
  • OS (e.g., Linux): Windows & Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Python version: 3.9
  • CUDA/cuDNN version: 11.1
  • GPU models and configuration: 3090 and A100
  • Lightning Flash version: 0.5
  • Lightning Version: 1.4.7
  • Any other relevant information: This only started happening once I finetuned or called to_onnx on a model with a custom preprocessor. It could be something else, but my gut says it’s probably the preprocess code.

Additional context

FWIW the Task that I’m using is very basic but initializes the example_input_array in the constructor like this:

self.example_input_array = torch.ones(
    (
        1,
        channels,
        config['image_size'],
        config['image_size']
    )
)
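
For what it's worth, the same ApplyToKeys transforms are fine with the dict batches the data source produces during training; they only trip on a bare tensor like the one above. A rough illustration (keys and shapes are placeholders, and I haven't run this exact snippet):

import torch
from torchvision import transforms as T

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys

resize = ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(224))

# Dict batch, as produced by the data source during training: the INPUT entry is resized.
batch = {
    DefaultDataKeys.INPUT: torch.ones(1, 3, 512, 512),
    DefaultDataKeys.TARGET: torch.zeros(1, 1),
}
out = resize(batch)

# Bare tensor, as handed over via example_input_array: same RuntimeError as in the trace.
try:
    resize(torch.ones(1, 3, 512, 512))
except RuntimeError as err:
    print(err)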

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 11 (10 by maintainers)

Top GitHub Comments

dlangerm commented on Sep 28, 2021 (1 reaction)

BTW, this is still an issue. For anyone else who runs into this, a hotfix is to override _apply_batch_transfer_handler in your Task:

    # Needs: from typing import Any, Optional / import torch / from flash.core.data.data_source import DefaultDataKeys
    def _apply_batch_transfer_handler(
        self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None
    ) -> Any:
        # summarize()/to_onnx() hand over example_input_array as a bare tensor; wrap it in the
        # dict the Flash per-batch transforms expect, then unwrap the result afterwards.
        if isinstance(batch, torch.Tensor):
            batch = {DefaultDataKeys.INPUT: batch}
            return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)[DefaultDataKeys.INPUT]
        return super()._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)

But I'm not comfortable making this a pull request, as _-prefixed functions are generally not meant to be overridden by convention, and I'm not exactly sure what other side effects this would have.

stale[bot] commented on Nov 29, 2021 (0 reactions)

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
