
[FEATURE] Improve PyTorch-user adoption by better aligning with expected behaviours of Dataset and DataLoader

See original GitHub issue

🚨🚨 Feature Request

  • Related to an existing Issue
  • A new implementation (Improvement, Extension)

Is your feature request related to a problem?

The current way the hub implements PyTorch datasets/dataloaders looks like the following:

dataset = hub.load(self.dataset_path).pytorch()

where .pytorch’s signature looks like so:

@hub_reporter.record_call
def pytorch(
    self,
    transform: Optional[Callable] = None,
    tensors: Optional[Sequence[str]] = None,
    num_workers: int = 1,
    batch_size: int = 1,
    drop_last: bool = False,
    collate_fn: Optional[Callable] = None,
    pin_memory: bool = False,
    shuffle: bool = False,
    buffer_size: int = 2048,
    use_local_cache: bool = False,
    use_progress_bar: bool = False,
):

The issue I have is that most people using PyTorch have implemented their Datasets and DataLoaders as per the standard guidelines. That means a dataset looks like this:

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

And a DataLoader looks like this:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

The main issue stems from the fact that most, if not all, of PyTorch’s existing codebases pass transforms to the Dataset class, and NOT to the DataLoader class.

The fact that hub passes transforms directly to the DataLoader will create a lot of friction when moving existing Datasets over to activeloop’s hub, or, in my case, when trying to build an ecosystem where hub datasets and other PyTorch datasets live in harmony.
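To make the contrast concrete, here is a minimal sketch of the two styles. The file paths, the imagenet path, and my_transform are placeholders, and the assumption that the callable passed to .pytorch() receives each hub sample is inferred from the signature above rather than confirmed:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

import hub

# Standard PyTorch style: the transform belongs to the Dataset,
# and the DataLoader only handles batching, shuffling and workers.
tfm = transforms.Compose(
    [transforms.ToPILImage(), transforms.Resize(224), transforms.ToTensor()]
)
plain_dataset = CustomImageDataset("labels.csv", "images/", transform=tfm)  # placeholder paths
plain_loader = DataLoader(plain_dataset, batch_size=64, shuffle=True)


def my_transform(sample):
    # Assumption: hub applies the transform per sample before collation;
    # here it is an identity function for illustration only.
    return sample


# hub style: the transform is handed to the loader factory instead,
# so the same callable cannot simply be reused with an ordinary torch Dataset.
hub_ds = hub.load("hub://activeloop/imagenet-train")
hub_loader = hub_ds.pytorch(transform=my_transform, batch_size=64, shuffle=True)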

Here’s how I am currently porting activeloop datasets into a format that conforms to the standard way PyTorch datasets are built. I hope it can help in some way.

import pathlib
from multiprocessing import cpu_count
from typing import Union, Optional, Callable

import hub
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from gate.base.utils.loggers import get_logger

log = get_logger(__name__, set_default_handler=True)


class ImageNetClassificationDataset(Dataset):
    def __init__(
        self,
        dataset_root: Union[str, pathlib.Path],
        set_name: str,
        input_transform: Optional[Callable],
        download: bool = True,
        target_transform: Optional[Callable] = None,
    ):
        super().__init__()
        
        self.dataset_path = f"hub://activeloop/imagenet-{set_name}"
        if download:
            hub.copy(
                src=self.dataset_path,
                dest=dataset_root,
                overwrite=True,
                num_workers=cpu_count(),
            )
            self.dataset = hub.load(dataset_root)
        else:
            self.dataset = hub.load(self.dataset_path)
        log.info(f"Loaded dataset with {self.dataset} items")
        self.download = download
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        x, y_labels, y_bounding_boxes = (
            torch.Tensor(item["images"].numpy()),
            torch.Tensor(item["labels"].numpy().astype(int)),
            torch.Tensor(item["boxes"].numpy()),
        )
        pilify = transforms.ToPILImage()
        x = pilify(x.permute([2, 0, 1]))
        x = self.input_transform(x)

        if self.target_transform is not None:
            y_labels = self.target_transform(y_labels)

        return {"image": x}, {"image": y_labels}

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 9 (5 by maintainers)

Top GitHub Comments

1 reaction
davidbuniat commented, Apr 21, 2022

Added a minor Option 1.2 that replaces index: int in __getitem__ with item: hub.Sample, which can directly return torch tensors.

class CustomDataset(hub.TorchDataset):
    def __init__(self, path, **kwargs):
        super().__init__(path, **kwargs)  # hub.load automatically happens here
        # some init

    def __getitem__(self, item: hub.Sample):
        # item['tensor_name'] will directly return a torch.Tensor
        # slightly tricky with complex htypes such as json or dicom
        images, labels = item
        # do some transformation
        # difficult to enable JIT optimizations here, but we can give it a try
        return images, labels
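For illustration only, a hypothetical usage of this proposal might look like the sketch below; hub.TorchDataset does not exist yet, the dataset path is a placeholder, and the assumption that a subclass would be consumable by a plain torch DataLoader is inferred from the sketch above:

from torch.utils.data import DataLoader

# Hypothetical: if the proposed hub.TorchDataset behaves like a map-style
# Dataset, a plain DataLoader should work without a hub-specific factory.
dataset = CustomDataset("hub://activeloop/imagenet-train")
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

for images, labels in loader:
    ...  # training step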

Do you think this would create a lot of limitations for users?

1 reaction
davidbuniat commented, Apr 21, 2022

The only issue with Option 1 is that it makes it quite tricky (or nearly impossible) to do JIT optimizations on the transform function, or a future near-zero-copy transfer to GPUs.

So we might keep this option for easy onboarding, but come up with another option for speed and performance if users prefer.

Read more comments on GitHub >
