Add `from_data` builder method to `Batch` that accepts batched data.

🚀 The feature, motivation and pitch

PyG currently offers only one method for creating `Batch` objects, namely `from_data_list`, which takes a list of `BaseData` objects as input. However, in some settings, such as reinforcement learning, it is common to process data that is already batched, e.g. by a `stable_baselines3` `SubprocVecEnv`. In these settings, the data must first be unbatched and then re-batched with `from_data_list`. It would be helpful to have a method that creates `Batch` objects directly from already-batched data.

Alternatives

Currently, our best alternatives are un-batching and re-batching, as described above, or overriding the methods in `SubprocVecEnv`.

Additional context

I am willing to work on a Pull Request for this.

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

yikai5518 commented, Aug 27, 2022 (1 reaction)

For `x`, we simply reshaped the tensor and kept the padding: the padded nodes have no edges to the rest of the graphs, so they do not take part in message passing. For `edge_index`, we added the `inc` offset tensor you suggested and removed the padded edges with boolean indexing. We removed the padded rows of `edge_attr` with the same kind of boolean mask.
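
A minimal plain-PyTorch sketch of why keeping the padded nodes is harmless. `sum_aggregate` is a hypothetical stand-in for one round of sum-based message passing (not a PyG layer), and node 5 plays the role of a padded node with no incident edges, so changing its features leaves every real node's output unchanged.

```python
import torch


def sum_aggregate(x, edge_index):
    """Hypothetical one-round message passing: out[dst] += x[src] for each edge."""
    src, dst = edge_index
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src])  # accumulate source features into target rows
    return out


# Two graphs padded to 3 nodes each and flattened: nodes 0-2 belong to graph 0,
# nodes 3-4 to graph 1, and node 5 is padding with no edges.
x = torch.randn(6, 4)
edge_index = torch.tensor([[0, 1, 3],
                           [1, 2, 4]])

out_before = sum_aggregate(x, edge_index)
x[5] = 1e6  # change only the padded node's features
out_after = sum_aggregate(x, edge_index)
```

Padded rows do still appear in any step that reads all of `x`, such as pooling over every node, so they would need masking there.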

The code that we used is as follows:

```python
import torch

# Padded inputs; padded entries of edge_index are assumed to be -1 and to sit
# at the end of each graph's edge list:
# x.shape          = (batch_size, max_num_nodes, embedding_size)
# edge_index.shape = (batch_size, 2, max_num_edges)
# edge_attr.shape  = (batch_size, max_num_edges, embedding_size)


def collate(x, edge_index, edge_attr):
    b, num_nodes, num_features = x.shape
    # Real edges per graph: count the non-padding (!= -1) entries of row 0.
    num_edges = torch.count_nonzero(edge_index + 1, dim=2)[:, 0]
    # Flatten node features; padded nodes remain as isolated nodes.
    x = x.reshape(b * num_nodes, num_features)

    # Offset the node indices of graph i by i * num_nodes.
    inc = torch.arange(0, b * num_nodes, num_nodes, device=edge_index.device).reshape(b, 1, 1)
    edge_index = (edge_index + inc).transpose(-2, -1)  # (b, max_num_edges, 2)

    max_num_edges = edge_index.shape[1]
    edge_dim = edge_attr.shape[-1]

    # Keep only the first num_edges[i] edges of graph i.
    grid = torch.arange(max_num_edges, device=edge_index.device).repeat(b, 2, 1).transpose(-2, -1)
    mask = grid < num_edges.reshape(b, 1, 1)
    edge_index = edge_index[mask].reshape(-1, 2).transpose(-2, -1)  # (2, num_real_edges)

    grid = torch.arange(max_num_edges, device=edge_attr.device).repeat(b, edge_dim, 1).transpose(-2, -1)
    mask = grid < num_edges.reshape(b, 1, 1)
    edge_attr = edge_attr[mask].reshape(-1, edge_dim)  # (num_real_edges, edge_dim)

    return x, edge_index, edge_attr
```
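
The masking steps in `collate` could also be written with a single `(batch_size, max_num_edges)` mask built by broadcasting, because indexing with a 2-D boolean mask selects whole edge rows. The sketch below is an alternative formulation, not the code from this thread; it assumes the same layout (padding value -1, padded edges at the end of each graph's list), and `collate_broadcast` and `collate_loop` are hypothetical names, the latter a per-graph reference used only to cross-check the result.

```python
import torch


def collate_broadcast(x, edge_index, edge_attr):
    """Hypothetical variant of collate() that builds one 2-D edge mask."""
    b, num_nodes, num_features = x.shape
    max_num_edges = edge_index.size(-1)
    # Real edges per graph: entries of row 0 that are not the -1 padding value.
    num_edges = (edge_index[:, 0] != -1).sum(dim=1)  # (b,)
    # A (1, E) range compared against (b, 1) counts broadcasts to a (b, E) mask.
    mask = torch.arange(max_num_edges, device=edge_index.device)[None, :] < num_edges[:, None]

    x = x.reshape(b * num_nodes, num_features)
    # Shift graph i's node indices by i * num_nodes.
    offset = (torch.arange(b, device=edge_index.device) * num_nodes).view(b, 1, 1)
    edge_index = (edge_index + offset).transpose(1, 2)[mask].t()  # (2, num_real_edges)
    edge_attr = edge_attr[mask]  # (num_real_edges, num_edge_features)
    return x, edge_index, edge_attr


def collate_loop(x, edge_index, edge_attr):
    """Hypothetical per-graph reference implementation for cross-checking."""
    b, num_nodes, _ = x.shape
    ei, ea = [], []
    for i in range(b):
        keep = edge_index[i, 0] != -1
        ei.append(edge_index[i][:, keep] + i * num_nodes)
        ea.append(edge_attr[i][keep])
    return x.reshape(b * num_nodes, -1), torch.cat(ei, dim=1), torch.cat(ea, dim=0)


x = torch.randn(2, 3, 4)
edge_index = torch.tensor([[[0, 1, 2], [1, 2, 0]],
                           [[0, -1, -1], [1, -1, -1]]])
edge_attr = torch.randn(2, 3, 5)
fast = collate_broadcast(x, edge_index, edge_attr)
ref = collate_loop(x, edge_index, edge_attr)
```

Building the mask once also avoids materialising the `repeat`ed grids, whose size grows with the number of edge features.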

For `separate()`, we only needed the node features, so we simply reshaped `x` back to its original padded shape.
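
A minimal sketch of that inverse step, assuming every graph was padded to the same `max_num_nodes` so that graph `i` occupies one contiguous block of rows. The signature of `separate` here is hypothetical; the thread does not show the actual one.

```python
import torch


def separate(x_flat, batch_size):
    """Hypothetical inverse of collate() for node features.

    Graph i occupies rows i * max_num_nodes to (i + 1) * max_num_nodes - 1,
    so a plain reshape recovers (batch_size, max_num_nodes, num_features).
    """
    return x_flat.reshape(batch_size, -1, x_flat.size(-1))


x = torch.randn(2, 3, 4)       # (batch_size, max_num_nodes, num_features)
x_flat = x.reshape(2 * 3, 4)   # the node layout collate() produces
x_back = separate(x_flat, batch_size=2)
```

When graphs are not padded to a common size, PyG's `torch_geometric.utils.to_dense_batch(x, batch)` performs the equivalent conversion from the `batch` assignment vector and also returns a mask of the real nodes.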

Thanks for your help in this issue!

ethanabrooks commented, Aug 26, 2022 (0 reactions)

Hi @rusty1s, a member of my team came up with a solution to this based on what you wrote. We will share it shortly. Thanks for your help!
