
[Community Sprint] Add missing type hints and TorchScript support 🚀


🚀 The feature, motivation and pitch

We are kicking off our very first community sprint!

The community sprint revolves around adding missing type hints and TorchScript support for various functions across PyG, with the aim of improving and cleaning up our core codebase. Each individual contribution is designed to take only around 30 minutes to two hours to complete.

The sprint begins on Wednesday, October 12th, with a kick-off meeting at 8am PST. The community sprint will last two weeks, and we will hold another live hangout once it has completed. If you are interested in helping out, please also join our PyG Slack channel #community-sprint-type-hints for more information. You can assign yourself to the type hints you are planning to work on here.

🚀 Add missing type hints and TorchScript support

Type hints are currently used inconsistently in the torch-geometric repository, and it would be nice to make them complete and consistent across all datasets, models and utilities. Adding type hints to models also helps us improve TorchScript coverage across the layers and models provided in nn.*.

Example

Take a look at the current implementation of contains_isolated_nodes:

def contains_isolated_nodes(edge_index, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes

Adding type hints to the function signature makes its inputs and output explicit, which improves code readability:

from typing import Optional

from torch import Tensor

def contains_isolated_nodes(
    edge_index: Tensor,
    num_nodes: Optional[int] = None,
) -> bool:
    ...

Importantly, type hints also let us use the function as part of a TorchScript model. TorchScript assumes that every argument without a type hint is a PyTorch tensor, which is clearly not the case for the num_nodes argument. As a result, the scripted function fails when it is called:

import torch

from torch_geometric.utils import contains_isolated_nodes

contains_isolated_nodes = torch.jit.script(contains_isolated_nodes)

contains_isolated_nodes(torch.tensor([[0, 1, 0], [1, 0, 0]]))
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File ".../pytorch_geometric/torch_geometric/utils/isolated.py", line 29, in contains_isolated_nodes
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
                ~~~~~~~~~~~~~~~ <--- HERE
    edge_index, _ = remove_self_loops(edge_index)
    return torch.unique(edge_index.view(-1)).numel() < num_nodes
RuntimeError: Cannot input a tensor of dimension other than 0 as a scalar argument
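For comparison, here is a minimal self-contained sketch of a fully typed variant that scripts cleanly. It assumes only plain PyTorch: the inline node-count inference and self-loop removal are simplified stand-ins for PyG's maybe_num_nodes and remove_self_loops, not their actual implementations.

```python
from typing import Optional

import torch
from torch import Tensor


def contains_isolated_nodes(
    edge_index: Tensor,
    num_nodes: Optional[int] = None,
) -> bool:
    # Simplified stand-in for PyG's maybe_num_nodes: infer the node count
    # from the largest index when num_nodes is not given.
    if num_nodes is None:
        n = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        n = num_nodes
    # Simplified stand-in for PyG's remove_self_loops.
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    # Isolated nodes exist if fewer than n distinct nodes touch an edge.
    return torch.unique(edge_index.reshape(-1)).numel() < n


scripted = torch.jit.script(contains_isolated_nodes)

edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])
print(contains_isolated_nodes(edge_index), scripted(edge_index))  # False False
print(scripted(edge_index, num_nodes=3))  # True: node 2 has no edges
```

The sketch stores the node count in a separate int variable instead of reassigning num_nodes, which keeps the types unambiguous for the TorchScript compiler.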

Guide to contributing

See here for a basic example to follow.

  1. Ensure you have read our contributing guidelines.
  2. Claim the functionality/model you want to improve here.
  3. Implement the changes as in https://github.com/pyg-team/pytorch_geometric/pull/5603. Ideally, add a test ensuring that the model/function can be converted into a TorchScript program and that it produces the same output. It is okay to add type hint support for multiple functions/models within a single PR as long as you have assigned yourself to each of them and the number of changed files stays reasonable to ease reviewing (e.g., no more than 10 touched files).
  4. Open a PR to the PyG repository and name it “[Type Hints] {model_name/function_name}”. In addition, ensure that the documentation renders properly (CI will build the documentation automatically for you). Afterwards, add your PR number to the “Improved type hint support” line in CHANGELOG.md.
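The check from step 3 can be sketched as a small reusable test helper. This is an illustration under assumptions, not PyG's testing API: full_test_enabled, assert_scripted_matches and count_occurrences are hypothetical names, and full_test_enabled only mimics the FULL_TEST=1 environment variable that PyG's is_full_test() responds to.

```python
import os
from typing import Optional

import torch
from torch import Tensor


def full_test_enabled() -> bool:
    # Hypothetical stand-in mimicking the FULL_TEST=1 switch used by
    # torch_geometric.testing.is_full_test().
    return os.environ.get('FULL_TEST', '0') == '1'


def assert_scripted_matches(fn, *args) -> None:
    # Compile `fn` with TorchScript and check it reproduces the eager output.
    expected = fn(*args)
    actual = torch.jit.script(fn)(*args)
    if isinstance(expected, Tensor):
        assert torch.equal(expected, actual)
    else:
        assert expected == actual


def count_occurrences(index: Tensor, num_nodes: Optional[int] = None) -> Tensor:
    # Hypothetical example function: how often each node index appears.
    if num_nodes is None:
        n = int(index.max()) + 1
    else:
        n = num_nodes
    return torch.bincount(index, minlength=n)


def test_count_occurrences() -> None:
    index = torch.tensor([0, 1, 1, 3])
    assert count_occurrences(index).tolist() == [1, 2, 0, 1]
    if full_test_enabled():  # TorchScript check only in full (nightly) runs
        assert_scripted_matches(count_occurrences, index)
```

Running pytest with FULL_TEST=1 set would then exercise the scripted path as well.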

Tips for making your PR

  • If you are unfamiliar with how type hints work, you can read the Python library documentation on them, but it is probably even easier to just look at another PR that added them.
  • The types will usually be obvious, e.g., Tensor, bool, int, float. Wrap them within Optional[*] whenever the argument can be None.
  • Specialized PyG type hints (e.g., Adj) are defined in typing.py.
  • In some rare cases, type hints are challenging to add, e.g., whenever a model/function supports a Union of different types (which may be the case for functions/models that also support SparseTensor, e.g., edge_index: Union[Tensor, SparseTensor]). In that case, TorchScript support can be achieved via the torch.jit._overload decorator. See here for an example.
  • The corresponding tests of PyG models and functions can be found in the test/ directory. For example, tests for torch_geometric/utils/isolated.py can be found in test/utils/test_isolated.py. You can run individual test files via pytest test/utils/test_isolated.py. It is only necessary to test TorchScript support for the nn.* and utils.* packages. No changes necessary for datasets.* and transforms.* packages.
  • Ensure that the TorchScript variant compiles and achieves the same output:
    import torch
    
    from torch_geometric.testing import is_full_test
    from torch_geometric.utils import contains_isolated_nodes
    
    ...
    
    edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]])
    assert contains_isolated_nodes(edge_index)
      
    if is_full_test():
        jit = torch.jit.script(contains_isolated_nodes)
        assert jit(edge_index)
    
    Note that we generally gate TorchScript tests behind is_full_test(), so that they only run nightly. You can enable full tests locally via the FULL_TEST=1 environment variable, e.g., FULL_TEST=1 pytest test/utils/test_isolated.py.
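The torch.jit._overload pattern mentioned in the tips can be sketched without SparseTensor. To stay runnable with plain PyTorch, this sketch assumes a hypothetical num_entries function that accepts either a Tensor or a List[int] in place of a Tensor/SparseTensor union. Keep in mind that torch.jit._overload is a private PyTorch API whose behavior may change between releases.

```python
from typing import List

import torch
from torch import Tensor


# Overload declarations: one stub per supported input type. Their body
# must be `pass` (or `...`); the types come from the type comments.
@torch.jit._overload
def num_entries(x):
    # type: (Tensor) -> int
    pass


@torch.jit._overload
def num_entries(x):
    # type: (List[int]) -> int
    pass


def num_entries(x):
    # Shared implementation. For each overload, `isinstance` is resolved
    # statically, so only the matching branch is compiled.
    if isinstance(x, Tensor):
        return x.numel()
    else:
        return len(x)


def total_entries(t: Tensor, xs: List[int]) -> int:
    # The compiler picks the overload matching each argument's type.
    return num_entries(t) + num_entries(xs)


scripted_total = torch.jit.script(total_entries)
print(scripted_total(torch.zeros(2, 3), [1, 2, 3]))  # 9
```

Scripting num_entries directly is expected to fail because it is overloaded; calling it from a typed function such as total_entries lets TorchScript choose the overload from the argument types.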

Functions/Models to update

This list may be incomplete. If you find a function that still lacks type hints or TorchScript tests, please let us know or add them on your own.

  • nn.MetaLayer (#5758)
  • nn.ChebConv (#5730)
  • nn.InstanceNorm (#5684)
  • nn.GraphSizeNorm (#5729)
  • nn.MessageNorm (#5847)
  • nn.DiffGroupNorm (#5757)
  • nn.TopKPooling (#5731)
  • nn.SAGPooling (#5810)
  • nn.EdgePooling (#5738)
  • nn.PANPooling (#5852)
  • nn.max_pool (#5735)
  • nn.avg_pool (#5734)
  • nn.max_pool_x (#5687)
  • nn.max_pool_neighbor_x (#5703)
  • nn.avg_pool_x (#5706)
  • nn.avg_pool_neighbor_x (#5707)
  • nn.Node2Vec (#5669)
  • nn.DeepGraphInfomax (#5688)
  • nn.InnerProductDecoder (#5842)
  • nn.GAE (#5842)
  • nn.VGAE (#5842)
  • nn.ARGA (#5726)
  • nn.ARGVA (#5726)
  • nn.SignedGCN (#5725)
  • nn.RENet (#5715)
  • nn.GraphUNet (#5710)
  • nn.SchNet (#5710)
  • nn.DimeNet (#5806)
  • nn.DimeNetPlusPlus (#5806)
  • nn.GNNExplainer (#5716)
  • nn.DeepGCNLayer (#5699)
  • nn.AttentiveFP (#5766)
  • nn.DenseGCNConv (#5664)
  • nn.DenseGINConv (#5842)
  • nn.DenseGraphConv (#5842)
  • nn.DenseSAGEConv (#5664)
  • nn.dense_diff_pool (#5754)
  • nn.dense_mincut_pool (#5756)
  • datasets.NELL (#5678)
  • datasets.PPI (#5678)
  • datasets.Reddit (#5695)
  • datasets.Reddit2 (#5695)
  • datasets.Yelp (#5747)
  • datasets.QM7b (#5678)
  • datasets.ZINC (#5678)
  • datasets.MoleculeNet (#5678)
  • datasets.MNISTSuperpixels (#5760)
  • datasets.ShapeNet (#5828)
  • datasets.ModelNet (#5701)
  • datasets.SHREC2016 (#5798)
  • datasets.TOSCA (#5797)
  • datasets.PCPNetDataset (#5797)
  • datasets.S3DIS (#5799)
  • datasets.ICEWS18 (#5666)
  • datasets.WILLOWObjectClass (#5781)
  • datasets.PascalVOCKeypoints (#5781)
  • datasets.PascalPF (#5781)
  • datasets.SNAPDataset (#5811)
  • datasets.SuiteSparseMatrixCollection (#5811)
  • datasets.WordNet18 (#5811)
  • datasets.WordNet18RR (#5811)
  • datasets.WebKB (#5778)
  • datasets.JODIEDataset (#5797)
  • datasets.MixHopSyntheticDataset (#5797)
  • datasets.UPFD (#5800)
  • transforms.Distance (#5685)
  • transforms.Cartesian (#5673)
  • transforms.LocalCartesian (#5675)
  • transforms.Polar (#5676)
  • transforms.Spherical (#5732)
  • transforms.PointPairFeatures (#5732)
  • transforms.OneHotDegree (#5667)
  • transforms.TargetIndegree (#5743)
  • transforms.RandomJitter (#5714)
  • transforms.RandomFlip (#5714)
  • transforms.RandomScale (#5714)
  • transforms.RandomRotate (#5714)
  • transforms.RandomShear (#5702)
  • transforms.KNNGraph (#5722)
  • transforms.FaceToEdge (#5722)
  • transforms.SamplePoints (#5733)
  • transforms.FixedPoints (#5733)
  • transforms.ToDense (#5668)
  • transforms.LaplacianLambdaMax (#5733)
  • transforms.ToSLIC (#5743)
  • transforms.GDC (#5752)
  • transforms.SIGN (#5736)
  • transforms.SVDFeatureReduction (#5743)
  • transforms.AddSelfLoops (#5753)
  • transforms.Center (#5753)
  • transforms.Compose (#5753)
  • transforms.Delaunay (#5753)
  • transforms.GCNNorm (#5753)
  • transforms.GenerateMeshNormals (#5753)
  • transforms.LinearTransformation (#5753)
  • transforms.LocalDegreeProfile (#5753)
  • transforms.NormalizeFeatures (#5753)
  • transforms.NormalizeRotation (#5753)
  • transforms.NormalizeScale (#5753)
  • transforms.RadiusGraph (#5753)
  • transforms.RandomLinkSplit (#5753)
  • transforms.RandomNodeSplit (#5753)
  • transforms.RemoveIsolatedNodes (#5753)
  • transforms.ToDevice (#5753)
  • transforms.ToSparseTensor (#5753)
  • transforms.ToUndirected (#5753)
  • transforms.TwoHop (#5753)
  • utils.remove_isolated_nodes (#5659)
  • utils.get_laplacian (#5682)
  • utils.to_dense_adj (#5682)
  • utils.dense_to_sparse (#5683)
  • utils.normalized_cut (#5665)
  • utils.grid (#5724)
  • utils.geodesic_distance (#5851)
  • utils.tree_decomposition (#5851)
  • utils.to_scipy_sparse_matrix (#5736)
  • utils.from_scipy_sparse_matrix (#5736)
  • utils.to_networkx (#5768)
  • utils.from_networkx (#5768)
  • utils.erdos_renyi_graph (#5768)
  • utils.stochastic_blockmodel_graph (#5768)
  • utils.barabasi_albert_graph (#5768)
  • utils.train_test_split_edges (#5737)
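To hunt for functions that still lack type hints, a quick check with Python's standard inspect module can help. This is a minimal sketch: find_unannotated and missing_annotations are hypothetical helpers, and because they only inspect signatures they say nothing about whether TorchScript compilation succeeds.

```python
import inspect
import types
from typing import Callable, Dict, List


def missing_annotations(fn: Callable) -> List[str]:
    """Return the parameter names (plus 'return') of `fn` lacking a hint."""
    sig = inspect.signature(fn)
    missing = [
        name
        for name, param in sig.parameters.items()
        if param.annotation is inspect.Parameter.empty
    ]
    if sig.return_annotation is inspect.Signature.empty:
        missing.append('return')
    return missing


def find_unannotated(module: types.ModuleType) -> Dict[str, List[str]]:
    """Map each public function defined in `module` to its missing hints."""
    report: Dict[str, List[str]] = {}
    for name, obj in vars(module).items():
        # Only look at public, plain Python functions.
        if name.startswith('_') or not inspect.isfunction(obj):
            continue
        # Skip helpers that were merely imported from another module.
        if obj.__module__ != module.__name__:
            continue
        missing = missing_annotations(obj)
        if missing:
            report[name] = missing
    return report


# Usage on a throwaway module built from source text.
demo = types.ModuleType('demo')
exec(
    'def typed(x: int, y: float = 0.0) -> float:\n'
    '    return x + y\n'
    'def untyped(edge_index, num_nodes=None):\n'
    '    return num_nodes\n',
    demo.__dict__,
)
print(find_unannotated(demo))
# {'untyped': ['edge_index', 'num_nodes', 'return']}
```

The __module__ filter keeps the report focused on functions a file actually defines, so helpers imported from elsewhere are not reported twice.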

Alternatives

No response

Additional context

No response

Issue Analytics

  • State: closed
  • Created a year ago
  • Reactions: 2
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

4 reactions
rusty1s commented, Oct 29, 2022

We reached 100% coverage of the community sprint! 🎉 This is amazing to see! I would like to deeply thank everyone who participated in this sprint. Thanks for improving our code base and for contributing to PyG!

1 reaction
rusty1s commented, Oct 17, 2022

Yeah, if you have found some transforms not included here, please add them to the list 😃
