
DataLoader: Dynamic Batch-size based on num_nodes/num_edges

❓ Questions & Help

I was using the RGCNConv layer with num_relations=4, in_channels=out_channels=512. In the forward step, I pass in a graph with 300 nodes and 22768 edges, which caused it to raise a CUDA OOM error saying it needs 22.23GB of memory while I only have 11.17GB.

The line causing the error is w = torch.index_select(w, 0, edge_type) in the message function of the RGCNConv class, which makes sense, as it was trying to create a float tensor of size [22768 x 512 x 512].

But copying the weights num_edges times seems inefficient and unnecessary. Is there another way to implement RGCN without copying the weights this many times?
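For reference, a quick back-of-the-envelope check of the reported allocation, followed by a sketch of the "transform node features once per relation, then gather per edge" alternative. The function and variable names below are only illustrative, not the library's implementation:

import torch

# Memory for the per-edge weight gather in float32:
# 22768 edges x 512 x 512 entries x 4 bytes ~= 22.23 GiB, matching the report.
num_edges, channels = 22768, 512
print(num_edges * channels * channels * 4 / 2**30)  # ~22.23

def rgcn_message(x, w, edge_index, edge_type):
    # x: [num_nodes, in_channels], w: [num_relations, in_channels, out_channels]
    # Transform every node under every relation first (num_relations matmuls);
    # peak memory scales with num_relations * num_nodes * out_channels
    # (~2.5 MB here) instead of num_edges * in_channels * out_channels (~22 GiB).
    h = torch.einsum('ni,rio->rno', x, w)  # [num_relations, num_nodes, out_channels]
    src = edge_index[0]                    # source node of each edge
    return h[edge_type, src]               # [num_edges, out_channels]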

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 13 (9 by maintainers)

Top GitHub Comments

1 reaction
Zacharias030 commented, Feb 5, 2020

Here’s a first hack of the above with all the caveats included:

import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Data, Batch
from torch._six import container_abcs, string_classes, int_classes
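# Note: torch._six was removed in recent PyTorch releases; on newer versions
# collections.abc, str and int can be used in place of container_abcs,
# string_classes and int_classes respectively.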


class NodeLimitedDataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (list or tuple, optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`[]`)
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=[],
                 max_num_nodes=None, **kwargs):
        self.max_num_nodes = max_num_nodes

        def collate(batch):
            elem = batch[0]
            if isinstance(elem, Data):
                # greedily add all samples that fit within self.max_num_nodes
                # and silently discard all others
                if max_num_nodes is not None:
                    num_nodes = 0
                    limited_batch = []
                    for elem in batch:
                        if num_nodes + elem.num_nodes <= self.max_num_nodes:
                            limited_batch.append(elem)
                            num_nodes += elem.num_nodes
                    return Batch.from_data_list(limited_batch, follow_batch)
                else:
                    return Batch.from_data_list(batch, follow_batch)
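            # The remaining branches mirror PyTorch's default_collate dispatch so
            # that non-Data elements (floats, ints, strings, mappings, namedtuples
            # and sequences) are still batched recursively.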
            elif isinstance(elem, float):
                return torch.tensor(batch, dtype=torch.float)
            elif isinstance(elem, int_classes):
                return torch.tensor(batch)
            elif isinstance(elem, string_classes):
                return batch
            elif isinstance(elem, container_abcs.Mapping):
                return {key: collate([d[key] for d in batch]) for key in elem}
            elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
                return type(elem)(*(collate(s) for s in zip(*batch)))
            elif isinstance(elem, container_abcs.Sequence):
                return [collate(s) for s in zip(*batch)]

            raise TypeError('DataLoader found invalid type: {}'.format(
                type(elem)))

        super(NodeLimitedDataLoader,
              self).__init__(dataset, batch_size, shuffle,
                             collate_fn=lambda batch: collate(batch), **kwargs)
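
A minimal usage sketch, assuming a standard PyG dataset such as TUDataset; the node cap of 4000 is an arbitrary illustration, and samples that do not fit the current mini-batch are silently dropped for that epoch, as noted in the caveats above:

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = NodeLimitedDataLoader(dataset, batch_size=64, shuffle=True,
                               max_num_nodes=4000)
for batch in loader:
    print(batch.num_graphs, batch.num_nodes)  # num_nodes stays <= 4000
    break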

1 reaction
rusty1s commented, Aug 12, 2019

This is a good request, and we currently do not support it. We could add a max_num_nodes or max_num_edges argument in addition to batch_size to the DataLoader. However, this requires us to implement our own data loading routine without relying on PyTorch to do this task for us, so it can be a bit tricky to get right, especially in combination with num_workers.
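One way to sketch that idea without a fully custom loading routine is a custom batch_sampler passed to the standard DataLoader: the sampler groups dataset indices under a node budget in the main process, so it composes with num_workers and no samples are dropped. Everything below (NodeBudgetSampler, max_num_nodes=4000, the dataset variable from the example above) is an illustrative assumption, not PyG API:

import torch
from torch_geometric.data import Batch

class NodeBudgetSampler(torch.utils.data.Sampler):
    # Yields lists of dataset indices whose total node count stays under
    # max_num_nodes (a single graph larger than the budget forms its own batch).
    def __init__(self, dataset, max_num_nodes, shuffle=False):
        self.dataset = dataset
        self.max_num_nodes = max_num_nodes
        self.shuffle = shuffle

    def __iter__(self):
        order = torch.randperm(len(self.dataset)) if self.shuffle \
            else torch.arange(len(self.dataset))
        batch, num_nodes = [], 0
        for idx in order.tolist():
            n = self.dataset[idx].num_nodes
            if batch and num_nodes + n > self.max_num_nodes:
                yield batch
                batch, num_nodes = [], 0
            batch.append(idx)
            num_nodes += n
        if batch:
            yield batch

loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=NodeBudgetSampler(dataset, max_num_nodes=4000, shuffle=True),
    collate_fn=Batch.from_data_list,
    num_workers=2)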
