
[Roadmap] `torch_geometric.nn.aggr` 🚀


🚀 The feature, motivation and pitch

The goal of this roadmap is to unify the concept of aggregation inside GNNs across both MessagePassing and global readouts. Currently, these concepts are separated, e.g., via MessagePassing.aggr = "mean" and global_mean_pool(...), although the underlying implementation is the same. In addition, some aggregations are only available as global pooling operators (global_sort_pool, Set2Set, …), while, in theory, they are also applicable during MessagePassing (and vice versa, e.g., SAGEConv.aggr = "lstm"). A further goal is the combination of multiple aggregations, which is useful both in MessagePassing (PNAConv, EGConv, …) and in global readouts.
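
To make the current separation concrete, here is a rough sketch using today's public API (the toy graph, feature sizes, and shapes are made up for illustration). Both calls reduce rows of a tensor grouped by an index, yet they go through unrelated code paths:

import torch
from torch_geometric.nn import SAGEConv, global_mean_pool

x = torch.randn(6, 16)                              # node features
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 0, 4, 5, 3]])     # toy edges
batch = torch.tensor([0, 0, 0, 1, 1, 1])            # two graphs

# Neighborhood aggregation, configured as a string on the layer:
conv = SAGEConv(16, 32, aggr="mean")
h = conv(x, edge_index)

# Graph-level readout, a separate free function with its own implementation:
out = global_mean_pool(h, batch)                    # shape: [2, 32]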

As such, we want to provide re-usable aggregations as part of a newly defined torch_geometric.nn.aggr.* package. Unifying these concepts also lets us implement optimizations and specialized kernels in a single place (e.g., fused kernels for multiple aggregations). After integration, the following usage becomes possible:

class MyConv(MessagePassing):
    def __init__(self):
        # Fixed aggregation given as a string, as before:
        super().__init__(aggr="mean")

class MyConv(MessagePassing):
    def __init__(self):
        # Aggregation given as a module with learnable parameters:
        super().__init__(aggr=LSTMAggr(channels=...))

class MyConv(MessagePassing):
    def __init__(self):
        # Multiple aggregations combined, mixing strings and modules:
        super().__init__(aggr=MultiAggr("mean", "max", Set2Set(channels=...)))
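
Because these aggregation modules are self-contained, the same instance should also work as a graph-level readout, replacing the separate global_*_pool functions. A minimal usage sketch, assuming the planned MeanAggr module accepts node features together with a grouping index (class name and call signature follow the naming in this roadmap and may differ in the released package):

import torch
from torch_geometric.nn.aggr import MeanAggr  # planned name, see roadmap below

aggr = MeanAggr()

x = torch.randn(6, 16)                    # node features for two graphs
batch = torch.tensor([0, 0, 0, 1, 1, 1])  # graph assignment per node

# The same module a MessagePassing layer would use internally, applied here
# as a global readout; equivalent to global_mean_pool(x, batch):
out = aggr(x, batch)                      # shape: [2, 16]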

Roadmap

The general roadmap looks as follows (ideally, each item implemented in a separate PR):

  • Define torch_geometric.nn.aggr.* and implement a BaseAggr abstract class (#4687); a sketch of a possible user-defined subclass follows this list
  • Add new aggregators:
    • DeepGCN aggregations SoftmaxAggr and PowerMeanAggr, cf. GENConv (#4687)
    • LSTMAggr, cf. SAGEConv (#4731)
  • Allow for multiple aggregations as part of a MultiAggr class (#4749)
  • Add support for class-resolver, similar to here (#4749, #4716)
  • Ensure torch.jit.script support (#4779)
  • Integrate new aggregations into MessagePassing interface (#4779)
  • Move aggregators from torch_geometric.nn.glob to torch_geometric.nn.aggr (respecting the new interface) and deprecate the old implementations:
    • MeanAggr, SumAggr, MaxAggr, MinAggr, MulAggr, VarAggr, StdAggr (#4687, #4749)
    • MedianAggr (#5098)
    • AttentionalAggr (#4986)
    • Set2Set (#4762)
    • GlobalSortAggr (#4957)
    • GraphMultiSetTransformer (#4973)
    • EquilibriumAggr (#4522)
    • Deprecate torch_geometric.nn.glob (#5039)
  • Update existing GNN layers to make use of new interface, e.g.:
    • SAGEConv (#4863)
    • PNAConv (#4864)
    • GravNetConv (#4865)
    • GENConv (#4866)
  • Add support for “reverse” aggregation resolver to keep message_and_aggregate functionality intact (#5084)
  • Support for multiple aggregations in SAGEConv (#5033)
  • MultiAggregation: Support for concat, concat+transform, sum, mean, max, attention (#5000, #5034)
  • Add semi_grad functionality to SoftmaxAggregation (#4995)
  • Add and verify torch_geometric.nn.aggr.* documentation (#5036, #5097, #5099, #5104)
  • Add a tutorial on the new concepts (#4927)
  • Kernel fusion: Optimize aggregations, e.g., by computing multiple aggregations in parallel (preferably discussed in a separate issue)
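
As referenced in the first roadmap item, here is a minimal sketch of what a user-defined aggregation could look like under the proposed interface. The BaseAggr class and the forward signature are assumptions based on this roadmap rather than a released API; the reduction itself uses torch_scatter, which PyG already builds on:

from typing import Optional

from torch import Tensor
from torch_scatter import scatter

from torch_geometric.nn.aggr import BaseAggr  # proposed abstract class (#4687)

class VarAggr(BaseAggr):
    # Per-group variance, computed as E[x^2] - E[x]^2 (cf. the VarAggr item
    # above; the body is only an illustration of the intended interface).
    def forward(self, x: Tensor, index: Tensor,
                dim_size: Optional[int] = None) -> Tensor:
        mean = scatter(x, index, dim=0, dim_size=dim_size, reduce="mean")
        mean_sq = scatter(x * x, index, dim=0, dim_size=dim_size, reduce="mean")
        return mean_sq - mean * mean

Defining the reduction once in such a module would make it usable both inside MessagePassing (aggr=VarAggr()) and as a global readout (VarAggr()(x, batch)), which is the central point of the unification.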

Any feedback and help from the community is highly appreciated!

cc: @lightaime @Padarn


Top GitHub Comments

2 reactions
rusty1s commented on Aug 12, 2022

Thanks everyone for the hard work @lightaime @Padarn. I think that the final outcome looks fantastic - many cool things to promote in our upcoming release 😃

1 reaction
rusty1s commented on Jul 17, 2022

Yes, indeed. There are no clear plans for the implementation yet, though. It will likely depend on PyTorch fusing these ops as part of TorchScript, or on us providing special CUDA kernels.
