Case where grad shape is wrong
I have come across a case where `grad` of a scalar with respect to another scalar returns a vector. I have found at least one place where changing the code makes it return a scalar instead (controlled by the `FAIL` flag in the example below).
- Check for duplicate issues.
- Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
```python
from jax import numpy as jnp, vmap, value_and_grad

# FAIL = True reproduces the wrong gradient shape; FAIL = False gives a scalar.
FAIL = True


def f2(y, z):
    v1 = z
    v2 = jnp.sum(y) + z
    if FAIL:
        return jnp.logaddexp(v1, v2)
    else:
        return v1 + v2


def f1(y, z):
    # Map f2 over the rows of y; z is closed over and not batched.
    v = vmap(lambda _y: f2(_y, z))(y)
    return jnp.sum(v)


if __name__ == '__main__':
    y = jnp.ones((3, 2))
    f = lambda z: f1(y, z)
    z = 0.1
    val, g = value_and_grad(f)(z)
    print(val.shape, g.shape)
    assert val.shape == ()
    assert g.shape == ()  # fails when FAIL = True
```
jax.__version__ = 0.2.9
jaxlib.__version__ = 0.1.61
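For reference, the values the reproduction above should produce can be worked out by hand. Since `logaddexp(z, s + z) = z + log(1 + exp(s))`, each row of `y` (with row sum `s = 2`) contributes a derivative of 1 with respect to `z`, so the gradient should be the scalar 3. The sketch below is a JAX-free check of that reasoning, not part of the original report; the helper names `logaddexp`, `f`, and `central_difference` are hypothetical stand-ins written only for this illustration.

```python
import math


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)) for plain Python floats.
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def f(z, rows=((1.0, 1.0),) * 3):
    # Same computation as f1 in the report (FAIL = True), without JAX:
    # sum over rows of logaddexp(z, sum(row) + z).
    return sum(logaddexp(z, sum(row) + z) for row in rows)


def central_difference(fn, x, eps=1e-6):
    # Finite-difference estimate of the scalar derivative d fn / d x.
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


z = 0.1
expected_val = 3 * (z + math.log1p(math.exp(2.0)))  # closed form of f(z)
expected_grad = 3.0                                 # one unit per row of y
numeric_grad = central_difference(f, z)
print(f(z), numeric_grad)
```

With a correct `value_and_grad`, `g` should therefore be a single number close to 3, not an array.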
Issue Analytics
- State:
- Created 2 years ago
- Comments:10 (10 by maintainers)
Top GitHub Comments
#6470 fixed it in both the MVCE and the larger algorithm.
I think I can get you a better fix sometime today.