RuntimeError: Trying to backward through the graph a second time
🐛 Describe the bug
I define a GAT network and pass its output to the decoder of a BART model, which uses it for cross-attention. The BART decoder has 12 attention layers. I then get the error "RuntimeError: Trying to backward through the graph a second time". When I don't use the GATConv layers, the error does not occur. I'd appreciate your help! The network is defined as follows:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=2):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        # conv1 concatenates its heads, so conv2 receives heads * hidden_channels features
        self.conv2 = GATConv(heads * hidden_channels, out_channels, heads=1, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=-1)
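Since the GAT output is consumed by all 12 decoder layers, it may help to note that using one tensor in several places within a single forward pass does not by itself cause this error: autograd sums the gradients from every use during one backward call. The sketch below checks this with plain PyTorch. The linear "encoder" and the 12 linear "layers" are hypothetical stand-ins, not the GAT or BART code from this issue.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a linear "encoder" instead of the GAT and
# 12 linear "layers" instead of BART's cross-attention layers.
encoder = nn.Linear(8, 8)
layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
x = torch.randn(4, 8)

enc_out = encoder(x)  # encoder output, computed once for this step
h = torch.zeros(4, 8)
for layer in layers:
    # Every layer reads the same enc_out, as each cross-attention layer would.
    h = h + torch.tanh(layer(enc_out))
loss = h.pow(2).mean()
loss.backward()  # one backward call through all 12 uses: no error

# Gradients from all 12 uses are accumulated into the encoder's weights.
print(encoder.weight.grad.shape)  # torch.Size([8, 8])
```

If the error still appears, the cause is more likely in how the encoder output is reused across backward calls than in how many layers read it.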
Environment
- PyG version: 1.7.2
- PyTorch version: 1.9.1
- OS: Linux
- Python version: 3.8
- CUDA/cuDNN version: 11.1
Issue Analytics
- Created 2 years ago
- Comments: 5 (3 by maintainers)
Top GitHub Comments
Do you have a minimal example to reproduce this? That would help track down the issue. Your GAT definition looks fine, though.

Thanks for confirming. I hope you can resolve your issue 😃 Perhaps someone on the PyTorch forums can help you further.
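A minimal reproduction of the kind requested above might look like the sketch below. It assumes (the issue does not show the training loop) that the encoder output is computed once and then reused across optimizer steps, which is one common way to hit this error: the first backward() frees the encoder's part of the graph, so the second step cannot backpropagate through it again. The modules are hypothetical stand-ins for the GAT and the BART decoder, and the fix shown is to recompute the encoder output inside every step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a linear "encoder" in place of the GAT and a
# linear "decoder" in place of the BART decoder.
encoder = nn.Linear(8, 8)
decoder = nn.Linear(8, 1)
opt = torch.optim.SGD(
    list(encoder.parameters()) + list(decoder.parameters()), lr=0.1
)
x = torch.randn(4, 8)


def train(steps, recompute):
    """Run a few optimizer steps; return the RuntimeError text, or None."""
    enc_out = torch.tanh(encoder(x))  # encoder output built once, before the loop
    try:
        for _ in range(steps):
            if recompute:
                # Fix: rebuild the encoder part of the graph on every step.
                enc_out = torch.tanh(encoder(x))
            loss = decoder(enc_out).pow(2).mean()
            opt.zero_grad()
            # Frees the whole graph, including the encoder part behind enc_out.
            loss.backward()
            opt.step()
    except RuntimeError as err:
        return str(err)
    return None


print(train(3, recompute=False))  # error message from the second step
print(train(3, recompute=True))   # None
```

Passing retain_graph=True to backward() would keep the buffers alive, but the gradients would then still be computed against the encoder output from the first step. Recomputing it each step (or calling .detach() if the encoder is meant to stay frozen) is usually the safer fix.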