Learnable edge weight in GCNConv

❓ Questions & Help

Hi, is it possible to set the edge_weight argument in GCNConv as a learnable parameter? I am attaching a related code snippet:

def __init__(self, input_channels, output_channels):
    self.conv1 = GCNConv(output_channels, output_channels, node_dim=1, bias=False, normalize=False)
    self.conv1.weight = nn.Parameter(torch.ones_like(self.conv1.weight, dtype=torch.float64))
    self.register_parameter('edge_weight', None)

def reset_parameters(self, edge_index_t):
    self.edge_weight = nn.Parameter(torch.ones((edge_index_t.size(1),), dtype=torch.float64, requires_grad=True))

def forward(self, data):
    x, edge_index_t = data.x, data.edge_index
    edge_index_t, __ = utils.add_self_loops(edge_index_t)
    if self.edge_weight is None:
        self.reset_parameters(edge_index_t)
    x = x.permute(1, 0, 2).contiguous()
    x = self.conv1(x, edge_index_t, edge_weight=self.edge_weight)
    x = x.permute(1, 0, 2).contiguous()
    data.x = x
    return data

As far as I understand, the edge_index parameter is used to build the adjacency matrix that determines which nodes' features are considered in the convolution. The PyTorch Geometric documentation, https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv , mentions that edge_weight can set the entries of the adjacency matrix to values other than 1. My intention is to make those entries learnable. After calling loss.backward(), I printed self.edge_weight: gradients were computed for it, but the weights were not updated. How do I solve this?

Top GitHub Comments

rusty1s commented, Jan 22, 2021

I cannot really reproduce this. This works just fine for me:

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv

dataset = Planetoid('/tmp/Planetoid', 'Cora', transform=T.NormalizeFeatures())
data = dataset[0]


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.edge_weight = torch.nn.Parameter(torch.ones(data.num_edges))
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv1(x, edge_index, self.edge_weight.sigmoid()).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, self.edge_weight.sigmoid())
        return F.log_softmax(x, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    F.nll_loss(out[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(data.x, data.edge_index), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs


best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    print(model.edge_weight[:5])
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_acc, best_val_acc, test_acc))
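
The example above creates edge_weight in __init__, so it is already part of model.parameters() when the optimizer is built. The question does not show where its optimizer is created, so the following is only a sketch of one possible cause, assuming the optimizer was constructed before the first forward pass. The module name LazyEdgeWeight and the scale parameter (a stand-in for the convolution weights) are hypothetical. A parameter that first appears inside forward() still receives gradients, but an optimizer built earlier never steps it:

```python
import torch


class LazyEdgeWeight(torch.nn.Module):
    """Hypothetical module that creates edge_weight on the first forward."""

    def __init__(self):
        super().__init__()
        # Stand-in for the convolution weights, so the optimizer has
        # at least one parameter when it is constructed.
        self.scale = torch.nn.Parameter(torch.ones(1))
        self.register_parameter('edge_weight', None)

    def forward(self, x):
        if self.edge_weight is None:
            # Created lazily, as in the question's reset_parameters().
            self.edge_weight = torch.nn.Parameter(torch.ones(x.size(0)))
        return (self.scale * self.edge_weight * x).sum()


model = LazyEdgeWeight()
# Assumption: the optimizer is built before the first forward pass,
# so it only knows about `scale`.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.tensor([1.0, 2.0, 3.0])
model(x).backward()
optimizer.step()

# Autograd filled in the gradient, but the optimizer never stepped it.
has_grad = model.edge_weight.grad is not None
unchanged = torch.equal(model.edge_weight, torch.ones(3))

# One possible fix: hand the new parameter to the existing optimizer.
optimizer.add_param_group({'params': [model.edge_weight]})
optimizer.zero_grad()
model(x).backward()
optimizer.step()
updated = not torch.equal(model.edge_weight, torch.ones(3))
```

If this assumption matches the original setup, either creating the parameter in __init__ (as in the example above) or registering it with add_param_group once it exists should let the optimizer update it.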

rusty1s commented, Jan 22, 2021

I see, this is indeed tricky. You cannot simply treat trainable parameters as data in PyTorch. An alternative is to create a huge Parameter tensor (for every edge that exists in the dataset), and to then select the edge weights you need during training:

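# In __init__: one learnable weight per edge in the whole dataset.
self.edge_weight = torch.nn.Parameter(torch.ones(num_all_edges))

def forward(self, x, edge_index, e_id):
    # Select the weights of the edges used in this forward pass.
    current_edge_weight = self.edge_weight[e_id]
    ...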

where e_id denotes the ids of edges for that single forward pass.
