Variable slice/index assignment graph breaking
I have been facing an issue when trying to create a graph of a Module in which some Variables have slice assignment operations in them. I have reduced the problem to the following example (ignore the commented-out `V = x` for now).
```python
import torch
from torch.autograd import Variable
from tensorboardX import SummaryWriter

class DummyModule(torch.nn.Module):
    def forward(self, x):
        V = Variable(torch.Tensor(2, 2))
        V[0, 0] = x
        # V = x
        return torch.sum(V * 3)

x = Variable(torch.Tensor([1]), requires_grad=True)
r = DummyModule()(x)
r.backward()
print(x.grad)

w = SummaryWriter()
x = Variable(torch.Tensor([1]), requires_grad=True)
w.add_graph(DummyModule(), x, verbose=True)
```
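As an aside, if the goal is only to place `x` at position `[0, 0]` of a larger tensor, the in-place slice assignment can be avoided entirely with out-of-place ops, which keeps the traced graph connected to the input. A minimal sketch (the module name `DummyModuleNoSlice` is mine, and it uses zero-padding as one possible replacement; modern tensor API rather than `Variable`):

```python
import torch
import torch.nn.functional as F

class DummyModuleNoSlice(torch.nn.Module):
    def forward(self, x):
        # Place x at position [0, 0] of a 2x2 tensor by padding one zero
        # on the right and one on the bottom, instead of V[0, 0] = x.
        V = F.pad(x.view(1, 1), (0, 1, 0, 1))
        return torch.sum(V * 3)

x = torch.ones(1, requires_grad=True)
r = DummyModuleNoSlice()(x)
r.backward()
print(x.grad)  # gradient still flows through the padded element
```

Because no tensor is mutated in place, every op in `forward` is a pure function of `x` and the tracer can follow the whole chain.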
The output from this is below, showing that the gradients flow fine but the graph is not being connected. If I insert another input Variable and other operations into the Module, `add_graph()` works without throwing an error, but the graph shows a disconnected input for `x`, so I suppose the nature of this error is that the only available input Variable is being interpreted as disconnected.
```
Variable containing:
 3
[torch.FloatTensor of size (1,)]

Traceback (most recent call last):
  File "test_grad.py", line 21, in <module>
    w.add_graph(DummyModule(), x, verbose=True)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/tensorboardX/writer.py", line 400, in add_graph
    self.file_writer.add_graph(graph(model, input_to_model, verbose))
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/tensorboardX/graph.py", line 44, in graph
    trace, _ = torch.jit.trace(model, args)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/torch/jit/__init__.py", line 251, in trace
    return TracedModule(f, nderivs=nderivs)(*args, **kwargs)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/filiped/anaconda/envs/pytorch0.4/lib/python3.6/site-packages/torch/jit/__init__.py", line 287, in forward
    torch._C._tracer_exit(out_vars)
RuntimeError: /Users/filiped/pytorch/torch/csrc/jit/tracer.h:117: getTracingState: Assertion `state` failed.
```
Moreover, if you uncomment the line `V = x` and comment out the line above it, so that no slice/index assignment is performed, you get, as expected:
```
Variable containing:
 3
[torch.FloatTensor of size (1,)]

graph(%0 : Float(1)) {
  %1 : UNKNOWN_TYPE = Constant[value={3}](), scope: DummyModule
  %2 : Float(1) = Mul[broadcast=1](%0, %1), scope: DummyModule
  %3 : Float() = Sum(%2), scope: DummyModule
  return (%3);
}
```
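For reference, in recent PyTorch releases the tracing entry point is `torch.jit.trace`, which returns the traced module directly rather than a `(trace, _)` tuple as in the 0.4-era code above. A minimal sketch of tracing the slice-free variant (`V = x`) of the module and printing its graph:

```python
import torch

class DummyModule(torch.nn.Module):
    def forward(self, x):
        # The V = x variant: no in-place slice assignment.
        return torch.sum(x * 3)

x = torch.ones(1, requires_grad=True)
traced = torch.jit.trace(DummyModule(), x)
print(traced.graph)  # shows the input wired through the Mul/Sum ops
```

The printed graph is the same kind of IR as shown above, with the input appearing connected to the `mul`/`sum` nodes.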
This was all executed in PyTorch 0.4.
(Edits: Did a couple rounds of re-simplifying the example.)
Issue Analytics
- Created: 6 years ago
- Comments: 18 (10 by maintainers)
Top GitHub Comments
Update: still not working in the PyTorch 0.4 release + tensorboardX master. Output of tensorboardX:

Interesting, I closed this because the onnx error had disappeared. Reopening it.