Learnable edge weight in GCNConv
❓ Questions & Help
Hi, is it possible to make the edge_weight argument of GCNConv a learnable parameter? Here is a code snippet related to this:
```python
import torch
import torch.nn as nn
from torch_geometric import utils
from torch_geometric.nn import GCNConv

class Net(nn.Module):
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv1 = GCNConv(output_channels, output_channels, node_dim=1,
                             bias=False, normalize=False)
        self.conv1.weight = nn.Parameter(torch.ones_like(self.conv1.weight, dtype=torch.float64))
        # Placeholder: the edge-weight Parameter is only created on the first forward pass
        self.register_parameter('edge_weight', None)

    def reset_parameters(self, edge_index_t):
        self.edge_weight = nn.Parameter(
            torch.ones(edge_index_t.size(1), dtype=torch.float64))

    def forward(self, data):
        x, edge_index_t = data.x, data.edge_index
        edge_index_t, _ = utils.add_self_loops(edge_index_t)
        if self.edge_weight is None:
            self.reset_parameters(edge_index_t)
        # PyG stores nodes in dim 0; move them to dim 1 to match node_dim=1
        x = x.permute(1, 0, 2).contiguous()
        x = self.conv1(x, edge_index_t, edge_weight=self.edge_weight)
        x = x.permute(1, 0, 2).contiguous()
        data.x = x
        return data
```
As far as I understand, the edge_index argument is used to build the adjacency matrix that determines which nodes' features are aggregated in the convolution. The PyTorch Geometric documentation (https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) mentions that the entries of the adjacency matrix can be set to values other than 1 via edge_weight. My intention is to make these adjacency values learnable. When I run loss.backward() and print self.edge_weight, gradients are computed for it, but the weights are not updated. How do I solve this?
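One plausible cause of the symptom described above (an assumption; the thread itself does not confirm it): in the snippet, edge_weight is created inside forward, so if the optimizer was built from model.parameters() before the first forward pass, the optimizer does not track that tensor. loss.backward() still fills its .grad, but optimizer.step() never updates it. The toy example below reproduces this in plain PyTorch without GCNConv; the names LazyEdgeWeight, EagerEdgeWeight and train_step are hypothetical.

```python
import torch
import torch.nn as nn


class LazyEdgeWeight(nn.Module):
    """Toy model that, like the snippet above, creates edge_weight on the first forward."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # ordinary, eagerly created parameter
        self.register_parameter("edge_weight", None)

    def forward(self, num_edges):
        if self.edge_weight is None:
            self.edge_weight = nn.Parameter(torch.ones(num_edges))
        return (self.scale * self.edge_weight).sum()


class EagerEdgeWeight(nn.Module):
    """Same toy model, but edge_weight is allocated in __init__ (edge count known up front)."""

    def __init__(self, num_edges):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.edge_weight = nn.Parameter(torch.ones(num_edges))

    def forward(self, num_edges):
        return (self.scale * self.edge_weight).sum()


def train_step(model, optimizer, num_edges):
    optimizer.zero_grad()
    loss = model(num_edges)
    loss.backward()
    optimizer.step()


# The optimizer is built before edge_weight exists, so it never tracks it:
# edge_weight receives a gradient, but SGD leaves its values untouched.
lazy = LazyEdgeWeight()
lazy_opt = torch.optim.SGD(lazy.parameters(), lr=0.1)
train_step(lazy, lazy_opt, num_edges=4)

# With eager allocation the optimizer tracks edge_weight and updates it.
eager = EagerEdgeWeight(num_edges=4)
eager_opt = torch.optim.SGD(eager.parameters(), lr=0.1)
train_step(eager, eager_opt, num_edges=4)

print(lazy.edge_weight.grad)   # non-zero gradient
print(lazy.edge_weight.data)   # still all ones
print(eager.edge_weight.data)  # moved by the SGD step
```

If the edge count is only known at runtime, alternatives are to run one forward pass before constructing the optimizer, or to register the new tensor afterwards with optimizer.add_param_group.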
Issue Analytics
- Created 3 years ago
- Comments: 5 (3 by maintainers)
Top GitHub Comments
I cannot really reproduce this. This works just fine for me:
I see, this is indeed tricky. You cannot simply treat trainable parameters as data in PyTorch. An alternative is to create one huge Parameter tensor (with an entry for every edge that exists in the dataset), and then select the edge weights you need during training, where e_id denotes the ids of the edges for that single forward pass.
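The large-Parameter approach can be sketched in plain PyTorch as follows. This is a minimal sketch under stated assumptions, not the maintainer's original snippet: GlobalEdgeWeights is a hypothetical name, the aggregation is a simplified, unnormalized stand-in for GCNConv's message passing (no weight matrix, no self-loops), and how e_id is obtained for each mini-batch (for example from a sampler that reports edge ids) is assumed. In a real model, the selected weights self.edge_weight[e_id] would be passed as the edge_weight argument of GCNConv instead.

```python
import torch
import torch.nn as nn


class GlobalEdgeWeights(nn.Module):
    """Hypothetical module: one learnable weight for every edge in the dataset."""

    def __init__(self, num_dataset_edges):
        super().__init__()
        # Allocated once in __init__, so the optimizer tracks it from the start.
        self.edge_weight = nn.Parameter(torch.ones(num_dataset_edges))

    def forward(self, x, edge_index, e_id):
        # Select the weights of the edges used in this forward pass.
        w = self.edge_weight[e_id]
        src, dst = edge_index
        # Unnormalized weighted sum: out[i] = sum over edges (j -> i) of w_ji * x[j].
        messages = w.unsqueeze(-1) * x[src]
        return torch.zeros_like(x).index_add(0, dst, messages)


model = GlobalEdgeWeights(num_dataset_edges=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# A mini-batch with 3 nodes and 3 edges; e_id holds the (hypothetical)
# dataset-level ids of those 3 edges.
x = torch.ones(3, 2)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
e_id = torch.tensor([2, 5, 7])

optimizer.zero_grad()
loss = model(x, edge_index, e_id).sum()
loss.backward()
optimizer.step()

# Only entries 2, 5 and 7 move (each gets gradient 2, so 1 - 0.1 * 2 = 0.8);
# all other edge weights stay at 1.
print(model.edge_weight.data)
```

Because the Parameter has a fixed size covering the whole dataset, it can be created before the optimizer and reused across mini-batches, which avoids treating a freshly created trainable tensor as per-batch data.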