torch.einsum() compatibility
Hi, I am testing your sample code in attention_augmented_conv.py with:
```python
tmp = torch.randn((16, 3, 32, 32))
a = AugmentedConv(3, 20, kernel_size=3, dk=40, dv=4, Nh=2, relative=True)
print(a(tmp).shape)
```
But it raises:
```
Traceback (most recent call last):
  File "attention_augmented_conv.py", line 131, in <module>
    print(a(tmp).shape)
  File "/Users/scouly/anaconda3/envs/Pytorch_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "attention_augmented_conv.py", line 44, in forward
    h_rel_logits, w_rel_logits = self.relative_logits(q)
  File "attention_augmented_conv.py", line 90, in relative_logits
    rel_logits_w = self.relative_logits_1d(q, key_rel_w, H, W, Nh, "w")
  File "attention_augmented_conv.py", line 99, in relative_logits_1d
    rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)
TypeError: einsum() takes 2 positional arguments but 3 were given
```
I'm guessing this is caused by a PyTorch version compatibility issue.
BTW, I am currently using PyTorch 0.4.1 on macOS.
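
For what it's worth, the traceback itself points at the cause: in PyTorch 0.4.x, `torch.einsum` took exactly two positional arguments, the equation string and a *list* of operand tensors, while the vararg form used in `relative_logits_1d` only works on newer releases. A minimal sketch of the list-based call, which should run on both old and new versions (the tensor shapes below are illustrative, not taken from the repo):

```python
import torch

# Illustrative shapes: q is (batch, Nh, H, W, d), rel_k is (2*W - 1, d)
q = torch.randn(16, 2, 32, 32, 20)
rel_k = torch.randn(63, 20)

# PyTorch 0.4.x requires the operands to be wrapped in a single list:
rel_logits = torch.einsum('bhxyd,md->bhxym', [q, rel_k])

# PyTorch >= 1.0 also accepts the tensors as separate positional arguments:
# rel_logits = torch.einsum('bhxyd,md->bhxym', q, rel_k)

print(rel_logits.shape)  # torch.Size([16, 2, 32, 32, 63])
```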
Top GitHub Comments
Thanks!
How can I change the padding of the convolution layer, for example to 0?
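
With a plain `nn.Conv2d`, padding is just a constructor argument; a minimal sketch (the standalone layer here is hypothetical, and the real internals of `AugmentedConv` may differ):

```python
import torch
import torch.nn as nn

# Hypothetical standalone layer; inside AugmentedConv the convolution
# would be built the same way, just with a different padding value.
conv = nn.Conv2d(in_channels=3, out_channels=20, kernel_size=3, padding=0)

x = torch.randn(16, 3, 32, 32)
print(conv(x).shape)  # torch.Size([16, 20, 30, 30]) -- padding=0 shrinks H and W
```

Note that with `padding=0` the spatial size shrinks by `kernel_size - 1`, so any branch that gets concatenated with the attention output would need matching dimensions.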