
Support for IterableDatasets


🚀 Feature

PyTorch introduced IterableDataset in v1.2 to allow users to process streams of data. ClassyVision currently only supports map-style datasets; it would be nice to extend support to IterableDatasets, since they are especially useful for processing video streams.

Motivation / Pitch

Map-style datasets assume each sample can be read completely independently of the others. In some situations, such as processing video streams, it is extremely expensive to open and close a video stream N times to read N frames. An IterableDataset allows the stream to be opened once, with samples yielded as the data loader requests them, which is substantially more efficient.
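
For illustration only (not part of the original issue), here is a minimal sketch of the difference. open_stream is a dummy stand-in for a real decoder such as PyAV or OpenCV, and the map-style version deliberately shows the naive reopen-per-frame cost:

import torch
from torch.utils.data import Dataset, IterableDataset


def open_stream(path):
    # Stand-in for an expensive video-container open; it just yields dummy
    # frame tensors so the sketch is runnable.
    return (torch.zeros(3, 4, 4) for _ in range(16))


class FrameDataset(Dataset):
    """Map-style: __getitem__ must be able to fetch any frame independently,
    so a naive implementation re-opens and scans the stream on every call."""

    def __init__(self, path: str, num_frames: int):
        self.path = path
        self.num_frames = num_frames

    def __len__(self):
        return self.num_frames

    def __getitem__(self, idx):
        stream = open_stream(self.path)  # opened once *per frame*
        for i, frame in enumerate(stream):
            if i == idx:
                return frame
        raise IndexError(idx)


class FrameStreamDataset(IterableDataset):
    """Iterable-style: the stream is opened once per epoch and frames are
    yielded in order as the DataLoader consumes them."""

    def __init__(self, path: str):
        self.path = path

    def __iter__(self):
        stream = open_stream(self.path)  # opened once per epoch
        yield from stream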

Additional context

There are a few important differences between map-style datasets and iterable datasets that break the current ClassyVision dataset paradigm:

  1. An iterable dataset has no __getitem__ method; it is replaced by the __iter__ method
  2. The __len__ method is optional in an IterableDataset
  3. Samplers do not work with IterableDatasets; sampling and shuffling have to be handled inside the dataset itself (see the sketch after this list)
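
To make the third difference concrete, here is a small illustrative sketch (not from the original issue) of the standard PyTorch pattern: shuffling and per-worker sharding happen inside __iter__, using torch.utils.data.get_worker_info():

import torch
from torch.utils.data import IterableDataset, get_worker_info


class ShardedIterable(IterableDataset):
    """Because a DataLoader sampler never sees an IterableDataset, shuffling
    and worker sharding are done inside __iter__."""

    def __init__(self, items, shuffle: bool = False, seed: int = 0):
        self.items = list(items)
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        order = list(range(len(self.items)))
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed)  # same seed in every worker => same permutation
            order = torch.randperm(len(order), generator=g).tolist()

        info = get_worker_info()
        if info is not None:
            # Each DataLoader worker takes an interleaved slice; otherwise every
            # worker would yield the full (duplicated) stream.
            order = order[info.id::info.num_workers]

        for i in order:
            yield self.items[i]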

I’ve come up with a template dataset called ChunkDataset that hides some of this complexity away, which might help beginner users get started. Nevertheless, in order to get this working with my code, I had to subclass ClassificationTask to modify it and had to create an entirely new base class for this style of dataset (ClassyDataset is not compatible).

chunk.py

from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset

# ClassyVision helpers used below; import paths assumed from the library's layout
from classy_vision.dataset.transforms import ClassyTransform
from classy_vision.generic.distributed_util import get_world_size
from classy_vision.generic.util import is_pos_int


class ChunkDataset(IterableDataset):
    def __init__(self, indices: List, process_fn: Callable):
        """
        Subset of IterableDataset to serve as a base class to process streams of data,
        such as audio, video or text.

        Args:
            indices (list): list of arguments to provide to process_fn
            process_fn (callable): function that processes the indices and returns an iterator
        """
        self.idxs = indices
        self._process_fn = process_fn
        self.epoch = 0
        self.shuffle = False

        # replacement for distributed sampler
        distributed = dist.is_available() and dist.is_initialized()
        if distributed:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            self.idxs = self.idxs[rank::num_replicas]

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.idxs), generator=g).tolist()
            idxs = [self.idxs[i] for i in indices]
        else:
            idxs = self.idxs

        return self._process_fn(idxs)

    @staticmethod
    def worker_init_fn(worker_id):
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset  # the dataset copy in this worker process
        n_workers = worker_info.num_workers
        dataset.idxs = dataset.idxs[worker_id::n_workers]

    def set_epoch(self, epoch):
        self.epoch = epoch
        return self

    def set_shuffle(self, shuffle: bool = True):
        self.shuffle = shuffle
        return self

    def __len__(self):
        raise NotImplementedError


class ClassyChunkDataset(IterableDataset):
    """
    Class representing a dataset abstraction to wrap a ChunkDataset.

    This class wraps a :class:`ChunkDataset` via the `dataset` attribute
    and configures the dataloaders needed to access the datasets.
    Transforms which need to be applied to the data should be specified in this class.
    ClassyChunkDataset can be instantiated from a configuration file as well.
    """

    def __init__(
        self,
        dataset: ChunkDataset,
        batchsize_per_replica: int,
        shuffle: bool,
        transform: Optional[Union[ClassyTransform, Callable]],
    ) -> None:
        """
        Constructor for a ClassyDataset.

        Args:
            batchsize_per_replica: Positive integer indicating batch size for each
                replica
            shuffle: Whether to shuffle between epochs
            transform: When set, transform to be applied to each sample
            num_samples: When set, this restricts the number of samples provided by
                the dataset
        """
        # Asserts:
        assert is_pos_int(
            batchsize_per_replica
        ), "batchsize_per_replica must be a positive int"
        assert isinstance(shuffle, bool), "shuffle must be a boolean"

        # Assignments:
        self.batchsize_per_replica = batchsize_per_replica
        self.shuffle = shuffle
        self.transform = transform
        self.dataset = dataset

        if self.shuffle:
            self.dataset = self.dataset.set_shuffle()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassyChunkDataset":
        """Instantiates a ClassyChunkDataset from a configuration.

        Args:
            config: A configuration for the ClassyChunkDataset.

        Returns:
            A ClassyChunkDataset instance.
        """
        raise NotImplementedError

    @classmethod
    def parse_config(cls, config: Dict[str, Any]):
        """
        This function parses out common config options.

        Args:
            config: A dict with the following string keys -

                | *batchsize_per_replica* (int): Must be a positive int, batch size
                |    for each replica
                | *use_shuffle* (bool): Whether to enable shuffling for the dataset
                | *num_samples* (int, optional): When set, restricts the number of
                |    samples in a dataset
                | *transforms*: list of transform configurations to be applied in order

        Returns:
            A tuple containing the following variables -
                | *transform_config*: Config for the dataset transform. Can be passed to
                |    :func:`transforms.build_transform`
                | *batchsize_per_replica*: Batch size per replica
                | *shuffle*: Whether we should shuffle between epochs
                | *num_samples*: When set, restricts the number of samples in a dataset
        """
        batchsize_per_replica = config.get("batchsize_per_replica")
        shuffle = config.get("use_shuffle")
        num_samples = config.get("num_samples")
        transform_config = config.get("transforms")
        return transform_config, batchsize_per_replica, shuffle, num_samples

    def __iter__(self):
        for sample in self.dataset:
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample

    def __len__(self):
        return len(self.dataset)

    def iterator(self, *args, **kwargs):
        """
        Returns an iterable which can be used to iterate over the data.

        Args:
            shuffle_seed (int, optional): Seed for the shuffle
            current_phase_id (int, optional): The epoch being fetched. Needed so that
                each epoch has a different shuffle order
        Returns:
            An iterable over the data
        """
        # TODO: Fix naming to be consistent (i.e. everyone uses epoch)
        epoch = kwargs.get("current_phase_id", 0)
        assert isinstance(epoch, int), "Epoch must be an int"

        self.dataset = self.dataset.set_epoch(epoch)

        return DataLoader(
            self,
            batch_size=self.batchsize_per_replica,
            num_workers=kwargs.get("num_workers", 0),
            pin_memory=kwargs.get("pin_memory", False),
            multiprocessing_context=kwargs.get("multiprocessing_context", None),
            worker_init_fn=self.worker_init_fn
        )

    def get_batchsize_per_replica(self):
        """
        Get the batch size per replica.

        Returns:
            The batch size for each replica.
        """
        return self.batchsize_per_replica

    def get_global_batchsize(self):
        """
        Get the global batch size, combined over all the replicas.

        Returns:
            The overall batch size of the dataset.
        """
        return self.get_batchsize_per_replica() * get_world_size()

    @staticmethod
    def worker_init_fn(worker_id):
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset.dataset  # the dataset copy in this worker process
        n_workers = worker_info.num_workers
        dataset.idxs = dataset.idxs[worker_id::n_workers]
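
To make the intended flow concrete, here is a hypothetical usage sketch (not part of the original snippet); decode_chunk is an illustrative stand-in for real stream decoding:

def decode_chunk(idxs):
    # Stand-in for real stream processing: in practice this would open each
    # video/audio file once and yield decoded samples from it.
    for idx in idxs:
        yield {"input": torch.zeros(3, 8, 8), "target": idx % 2}


chunk_dataset = ChunkDataset(indices=list(range(100)), process_fn=decode_chunk)

dataset = ClassyChunkDataset(
    dataset=chunk_dataset,
    batchsize_per_replica=8,
    shuffle=True,
    transform=None,
)

# iterator() builds the DataLoader; worker_init_fn shards the indices across workers
loader = dataset.iterator(current_phase_id=0, num_workers=2)
for batch in loader:
    pass  # feed batch["input"] / batch["target"] to the task's train step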

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

1 reaction
miguelvr commented, Mar 25, 2020

indeed, closing the issue

0 reactions
mannatsingh commented, Mar 25, 2020

@miguelvr I think https://github.com/facebookresearch/ClassyVision/pull/455 should solve this issue. We should still document this in a tutorial - I’ll create a separate issue for that. Please let me know if you are fine with closing this issue out.

