Support for IterableDatasets
## 🚀 Feature
PyTorch introduced the `IterableDataset` in v1.2 to allow users to process streams of data. ClassyVision currently only supports map-style datasets; it would be nice to extend support to `IterableDataset`s, given that they are especially useful for processing video streams.
## Motivation / Pitch
Map-style datasets assume each sample can be read independently of the others. In some situations, such as processing video streams, it is extremely expensive to open and close a video stream N times in order to read N frames. An `IterableDataset` lets the stream be opened once, with samples yielded as the data loader requests them, which is substantially more efficient.
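To make this concrete, here is a minimal sketch (not ClassyVision code) of an `IterableDataset` that keeps each video stream open and yields frames lazily; `decode_frames` is a hypothetical stand-in for whatever decoder you use, not part of PyTorch or ClassyVision.

```python
from torch.utils.data import IterableDataset


class VideoFrameStream(IterableDataset):
    """Sketch only: open each video once and yield its frames lazily."""

    def __init__(self, video_paths, decode_frames):
        self.video_paths = video_paths
        # `decode_frames` is a placeholder for an actual decoder (e.g. a PyAV wrapper).
        self.decode_frames = decode_frames

    def __iter__(self):
        for path in self.video_paths:
            # Each stream is opened once; frames are yielded as the DataLoader
            # requests them, instead of re-opening the file once per frame.
            yield from self.decode_frames(path)
```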
## Additional context
There are a few important differences between map-style datasets and iterable datasets that break the current ClassyVision dataset paradigm:

- In an iterable dataset there is no `__getitem__` method; it is replaced by the `__iter__` method (illustrated in the sketch below)
- The `__len__` method is optional in an iterable dataset
- Samplers do not work with `IterableDataset`s; sampling and shuffling have to be handled inside the dataset
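For concreteness, a minimal sketch (not ClassyVision code) of the two access patterns:

```python
from torch.utils.data import Dataset, IterableDataset


class MapStyleExample(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):  # required
        return len(self.samples)

    def __getitem__(self, index):  # random access; a Sampler chooses the indices
        return self.samples[index]


class IterableExample(IterableDataset):
    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):  # sequential access; no Sampler is involved,
        return iter(self.samples)  # so shuffling must happen inside the dataset
```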
I’ve come up with a template dataset called `ChunkDataset` that hides some of this complexity away, which might help beginner users get started. Nevertheless, in order to get this working with my code, I had to subclass `ClassificationTask` to modify it, and I had to create an entirely new base class for this style of dataset (`ClassyDataset` is not compatible).
`chunk.py`:

```python
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, IterableDataset

# Classy Vision helpers used below (import paths assumed from the library layout).
from classy_vision.dataset.transforms import ClassyTransform
from classy_vision.generic.distributed_util import get_world_size
from classy_vision.generic.util import is_pos_int


class ChunkDataset(IterableDataset):
    def __init__(self, indices: List, process_fn: Callable):
        """
        Subset of IterableDataset to serve as a base class for processing
        streams of data, such as audio, video or text.

        Args:
            indices (list): list of arguments to provide to process_fn
            process_fn (callable): function that processes the indices and
                returns an iterator
        """
        self.idxs = indices
        self._process_fn = process_fn
        self.epoch = 0
        self.shuffle = False

        # Replacement for DistributedSampler: each rank keeps a strided slice
        # of the indices so that replicas see disjoint chunks.
        distributed = dist.is_available() and dist.is_initialized()
        if distributed:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            self.idxs = self.idxs[rank::num_replicas]

    def __iter__(self):
        # Deterministically shuffle based on the epoch so that every replica
        # uses the same permutation.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.idxs), generator=g).tolist()
            idxs = [self.idxs[i] for i in indices]
        else:
            idxs = self.idxs
        return self._process_fn(idxs)

    @staticmethod
    def worker_init_fn(worker_id):
        # Runs in each DataLoader worker: shard the (already rank-sharded)
        # indices across workers so that samples are not duplicated.
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset  # the dataset copy in this worker process
        n_workers = worker_info.num_workers
        dataset.idxs = dataset.idxs[worker_id::n_workers]

    def set_epoch(self, epoch):
        self.epoch = epoch
        return self

    def set_shuffle(self, shuffle: bool = True):
        self.shuffle = shuffle
        return self

    def __len__(self):
        raise NotImplementedError

class ClassyChunkDataset(IterableDataset):
    """
    Class representing a dataset abstraction that wraps a ChunkDataset.

    This class wraps a :class:`ChunkDataset` via the `dataset` attribute
    and configures the dataloaders needed to access the data.
    Transforms which need to be applied to the data should be specified in
    this class. ClassyChunkDataset can be instantiated from a configuration
    file as well.
    """

    def __init__(
        self,
        dataset: ChunkDataset,
        batchsize_per_replica: int,
        shuffle: bool,
        transform: Optional[Union[ClassyTransform, Callable]],
    ) -> None:
        """
        Constructor for a ClassyChunkDataset.

        Args:
            dataset: The wrapped ChunkDataset
            batchsize_per_replica: Positive integer indicating batch size for
                each replica
            shuffle: Whether to shuffle between epochs
            transform: When set, transform to be applied to each sample
        """
        # Asserts:
        assert is_pos_int(
            batchsize_per_replica
        ), "batchsize_per_replica must be a positive int"
        assert isinstance(shuffle, bool), "shuffle must be a boolean"

        # Assignments:
        self.batchsize_per_replica = batchsize_per_replica
        self.shuffle = shuffle
        self.transform = transform
        self.dataset = dataset
        if self.shuffle:
            self.dataset = self.dataset.set_shuffle()

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ClassyChunkDataset":
        """Instantiates a ClassyChunkDataset from a configuration.

        Args:
            config: A configuration for the ClassyChunkDataset.

        Returns:
            A ClassyChunkDataset instance.
        """
        raise NotImplementedError

    @classmethod
    def parse_config(cls, config: Dict[str, Any]):
        """
        This function parses out common config options.

        Args:
            config: A dict with the following string keys -

                | *batchsize_per_replica* (int): Must be a positive int, batch
                |   size for each replica
                | *use_shuffle* (bool): Whether to enable shuffling for the dataset
                | *num_samples* (int, optional): When set, restricts the number
                |   of samples in a dataset
                | *transforms*: list of transform configurations to be applied
                |   in order

        Returns:
            A tuple containing the following variables -

                | *transform_config*: Config for the dataset transform. Can be
                |   passed to :func:`transforms.build_transform`
                | *batchsize_per_replica*: Batch size per replica
                | *shuffle*: Whether we should shuffle between epochs
                | *num_samples*: When set, restricts the number of samples in a
                |   dataset
        """
        batchsize_per_replica = config.get("batchsize_per_replica")
        shuffle = config.get("use_shuffle")
        num_samples = config.get("num_samples")
        transform_config = config.get("transforms")
        return transform_config, batchsize_per_replica, shuffle, num_samples

    def __iter__(self):
        # Stream samples from the wrapped ChunkDataset, applying the transform
        # on the fly.
        for sample in self.dataset:
            if self.transform is not None:
                sample = self.transform(sample)
            yield sample

    def __len__(self):
        return len(self.dataset)

    def iterator(self, *args, **kwargs):
        """
        Returns an iterable which can be used to iterate over the data.

        Args:
            shuffle_seed (int, optional): Seed for the shuffle
            current_phase_id (int, optional): The epoch being fetched. Needed so
                that each epoch has a different shuffle order

        Returns:
            An iterable over the data
        """
        # TODO: Fix naming to be consistent (i.e. everyone uses epoch)
        epoch = kwargs.get("current_phase_id", 0)
        assert isinstance(epoch, int), "Epoch must be an int"
        self.dataset = self.dataset.set_epoch(epoch)

        return DataLoader(
            self,
            batch_size=self.batchsize_per_replica,
            num_workers=kwargs.get("num_workers", 0),
            pin_memory=kwargs.get("pin_memory", False),
            multiprocessing_context=kwargs.get("multiprocessing_context", None),
            worker_init_fn=self.worker_init_fn,
        )

    def get_batchsize_per_replica(self):
        """
        Get the batch size per replica.

        Returns:
            The batch size for each replica.
        """
        return self.batchsize_per_replica

    def get_global_batchsize(self):
        """
        Get the global batch size, combined over all the replicas.

        Returns:
            The overall batch size of the dataset.
        """
        return self.get_batchsize_per_replica() * get_world_size()

    @staticmethod
    def worker_init_fn(worker_id):
        # The DataLoader wraps the ClassyChunkDataset, so the wrapped
        # ChunkDataset lives at worker_info.dataset.dataset.
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset.dataset  # the dataset copy in this worker process
        n_workers = worker_info.num_workers
        dataset.idxs = dataset.idxs[worker_id::n_workers]
```
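For illustration, a sketch of how the two classes above might be wired together; `read_video_frames` and the file names are hypothetical placeholders, not ClassyVision or PyTorch APIs.

```python
def process_videos(video_paths):
    # process_fn for ChunkDataset: open each chunk (here, a video file) once
    # and yield samples from it.
    for path in video_paths:
        for frame, label in read_video_frames(path):  # hypothetical decoder
            yield {"input": frame, "target": label}


video_paths = ["clip_000.mp4", "clip_001.mp4"]  # placeholder chunk list
dataset = ClassyChunkDataset(
    dataset=ChunkDataset(indices=video_paths, process_fn=process_videos),
    batchsize_per_replica=8,
    shuffle=True,
    transform=None,
)
loader = dataset.iterator(current_phase_id=0, num_workers=2)
for batch in loader:
    pass  # train on batch["input"] / batch["target"]
```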
## Top GitHub Comments

> @miguelvr I think https://github.com/facebookresearch/ClassyVision/pull/455 should solve this issue. We should still document this in a tutorial - I’ll create a separate issue for that. Please let me know if you are fine with closing this issue out.

> Indeed, closing the issue.