Adding `Dataset` and `DataLoader`-like functionality in KerasCV
@sayakpaul @LukeWood @ianjjohnson
Dataset loading and preprocessing in vision is generally a messy ordeal, and the extra boilerplate code needed for TFRecords makes it worse. One possible solution is to have classes that consolidate all of the preprocessing logic, with another set of classes managing data augmentation. Concretely, I propose two kinds of classes: one for loading and preprocessing, and one for augmentation.
I understand that the code is a bit crude, but the idea is to provide an interface for users to consolidate all of their preprocessing and augmentation logic in one (or two) classes. We can extend this by providing frequently used APIs like `TFRecordLoader` and `ImageDirectoryLoader`.
Sample code:
import tensorflow as tf


class DataLoader:
    """
    This class will _create_ a `tf.data.Dataset` instance.
    """

    def __init__(self, source):
        # `source` can be anything the user wants: for example, a path to a
        # directory, or a list of paths to TFRecords.
        self.source = source

    def get_dataset(self):
        # Extract data from the source in whichever way is necessary.
        # For example, the logic to read TFRecords can be written here.
        raise NotImplementedError
class DataAugmenter:  # Open to suggestions for a better name :)
    """
    This class will _consume_ a `tf.data.Dataset` instance.
    """

    def __init__(self, dataset):  # a `tf.data.Dataset` object
        self.dataset = dataset

    def augment(self, example):
        # All augmentation logic goes here.
        raise NotImplementedError

    def get_dataset(self, batch_size):
        # Augment in parallel, then batch, and prefetch last so that input
        # processing can overlap with training.
        ds = self.dataset.map(self.augment, num_parallel_calls=tf.data.AUTOTUNE)
        ds = ds.batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)
        return ds.prefetch(tf.data.AUTOTUNE)
class TFRecordLoader(DataLoader):
    def __init__(self, list_of_tfrecord_paths, tfrecords_format):
        super().__init__(source=list_of_tfrecord_paths)
        self.tfrecords_format = tfrecords_format

    def get_dataset(self):  # one of the APIs that we can provide
        # `self.source` holds the TFRecord paths passed to the constructor.
        files = tf.data.Dataset.list_files(self.source)
        ds = files.interleave(
            tf.data.TFRecordDataset,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        # `decode_example_fn` is omitted here for the sake of brevity.
        ds = ds.map(self.decode_example_fn, num_parallel_calls=tf.data.AUTOTUNE)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        return ds
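To make the proposal concrete, here is a rough end-to-end sketch of how the two kinds of classes might compose. Everything beyond the class names in the sample is an assumption made for illustration: the toy 2x2 "images", the `write_toy_tfrecord` helper, the body of `decode_example_fn` (omitted above), the `FlipAugmenter` subclass, and the choice to pass examples around as dicts so that `augment(self, example)` receives a single argument.

```python
import os
import tempfile

import tensorflow as tf


class DataLoader:
    """Creates a `tf.data.Dataset` from an arbitrary source."""

    def __init__(self, source):
        self.source = source

    def get_dataset(self):
        raise NotImplementedError


class TFRecordLoader(DataLoader):
    def __init__(self, list_of_tfrecord_paths, tfrecords_format):
        super().__init__(source=list_of_tfrecord_paths)
        self.tfrecords_format = tfrecords_format

    def decode_example_fn(self, serialized):
        # Hypothetical decoder: parse one serialized `tf.train.Example`
        # using the feature spec stored in `tfrecords_format`.
        parsed = tf.io.parse_single_example(serialized, self.tfrecords_format)
        return {
            "image": tf.reshape(parsed["image"], (2, 2, 1)),
            "label": parsed["label"],
        }

    def get_dataset(self):
        files = tf.data.Dataset.list_files(self.source)
        ds = files.interleave(
            tf.data.TFRecordDataset,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        return ds.map(self.decode_example_fn, num_parallel_calls=tf.data.AUTOTUNE)


class DataAugmenter:
    """Consumes a `tf.data.Dataset` and returns augmented batches."""

    def __init__(self, dataset):
        self.dataset = dataset

    def augment(self, example):
        raise NotImplementedError

    def get_dataset(self, batch_size):
        ds = self.dataset.map(self.augment, num_parallel_calls=tf.data.AUTOTUNE)
        return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


class FlipAugmenter(DataAugmenter):
    # Hypothetical subclass: a deterministic horizontal flip, so the
    # result is easy to inspect.
    def augment(self, example):
        flipped = tf.image.flip_left_right(example["image"])
        return {"image": flipped, "label": example["label"]}


def write_toy_tfrecord(path, num_examples):
    # Illustration-only helper: writes tiny 2x2 single-channel "images".
    with tf.io.TFRecordWriter(path) as writer:
        for i in range(num_examples):
            feature = {
                "image": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[0.0, 1.0, 2.0, 3.0])
                ),
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[i])
                ),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


tfrecord_path = os.path.join(tempfile.mkdtemp(), "toy.tfrecord")
write_toy_tfrecord(tfrecord_path, num_examples=5)

feature_spec = {
    "image": tf.io.FixedLenFeature([4], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
raw_ds = TFRecordLoader([tfrecord_path], feature_spec).get_dataset()
train_ds = FlipAugmenter(raw_ds).get_dataset(batch_size=2)

for batch in train_ds:
    print(batch["image"].shape, batch["label"].numpy())
```

With this split, the loader owns everything specific to the storage format, and the augmenter only ever sees a `tf.data.Dataset`, so either side could be swapped out independently.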
Please see https://github.com/keras-team/keras-cv/issues/78#issuecomment-1070468749.
Issue Analytics
- Created a year ago
- Comments: 7 (4 by maintainers)
Top GitHub Comments
For now, let’s hold off on this and not expand our API too much. We have `tf.data`. We can do some loading like `keras.datasets` does, but let’s keep it minimal!
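As a rough illustration of the minimal route, plain `tf.data` already covers the map/batch/prefetch pipeline once the data is in memory. The sketch below uses synthetic NumPy arrays as a stand-in for the `(x_train, y_train)` arrays that a `keras.datasets` loader returns; the array shapes and the `preprocess` and `augment` functions are assumptions made for the example, not anything KerasCV provides.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the NumPy arrays returned by a `keras.datasets`
# loader (real loaders download data, which this sketch avoids).
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(10, 8, 8, 3), dtype=np.uint8)
y_train = rng.integers(0, 2, size=(10,), dtype=np.int64)


def preprocess(image, label):
    # Scale pixel values to the [0, 1] range.
    return tf.cast(image, tf.float32) / 255.0, label


def augment(image, label):
    # Example augmentation: random horizontal flip.
    return tf.image.random_flip_left_right(image), label


train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(4)
    .prefetch(tf.data.AUTOTUNE)
)

for images, labels in train_ds:
    print(images.shape, labels.shape)
```

Here the dataset yields `(image, label)` tuples, which `map` unpacks into two arguments; that is why `preprocess` and `augment` take two parameters rather than a single `example`.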
Cool, I’m beginning to flesh out the details of what I think we want to offer.