
Case where grad shape is wrong

See original GitHub issue

I have come across a case where the gradient of a scalar with respect to another scalar comes back as a vector. I have found at least one place where changing the code yields a scalar instead (controlled by the FAIL flag in the example below).

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
from jax import numpy as jnp, vmap, value_and_grad

FAIL = True  # True: use jnp.logaddexp (gradient shape comes back wrong); False: plain addition (shape is correct)


def f2(y, z):
    v1 = z
    v2 = jnp.sum(y) + z

    if FAIL:
        return jnp.logaddexp(v1, v2)
    else:
        return v1 + v2


def f1(y, z):
    v = vmap(lambda _y: f2(_y, z))(y)
    return jnp.sum(v)


if __name__ == '__main__':
    y = jnp.ones((3, 2))
    f = lambda z: f1(y, z)
    z = 0.1
    val, g = value_and_grad(f)(z)
    print(val.shape, g.shape)
    assert val.shape == ()
    assert g.shape == ()  # fails on jax 0.2.9 when FAIL is True: g comes back as a vector

jax.__version__ = 0.2.9
jaxlib.__version__ = 0.1.61
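
A minimal sketch of the expected behaviour, for contrast: with the FAIL = False branch (plain addition instead of logaddexp), a scalar-valued function of a scalar input gives a scalar value and a scalar gradient. The helper names below are illustrative, not taken from the original report.

from jax import numpy as jnp, vmap, value_and_grad

def f2_ok(y, z):
    # the FAIL = False branch of the MVCE: plain addition instead of logaddexp
    return z + (jnp.sum(y) + z)

def f1_ok(y, z):
    return jnp.sum(vmap(lambda _y: f2_ok(_y, z))(y))

y = jnp.ones((3, 2))
val, g = value_and_grad(lambda z: f1_ok(y, z))(0.1)
print(val.shape, g.shape)  # expected: () () -- both scalars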

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 10 (10 by maintainers)

Top GitHub Comments

1 reaction
Joshuaalbert commented, Apr 16, 2021

#6470 fixed it in both the MVCE and the larger algorithm.
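
(A small sketch for later readers, not taken from the thread: once your installed jax includes #6470, re-running the MVCE above should print "() ()" and both assertions should pass.)

import jax

print(jax.__version__)  # the report above was against jax 0.2.9 / jaxlib 0.1.61
# On a release containing #6470, the MVCE's final assertions
#   assert val.shape == ()  and  assert g.shape == ()
# should both succeed even with FAIL = True.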

1 reaction
mattjj commented, Apr 15, 2021

I think I can get you a better fix sometime today.

Read more comments on GitHub >

Top Results From Across the Web

RuntimeError occurs in PyTorch backward function
In any case, if you need to pass gradient to the out.backward() , make sure that it has the same shape as the...
Read more >
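
Not from this thread, but a short PyTorch illustration of the point that result makes: when the output is not a scalar, backward() needs an explicit gradient argument with the same shape as the output.

import torch

x = torch.ones(3, requires_grad=True)
out = 2 * x                      # shape (3,): not a scalar

# Calling out.backward() with no argument raises
# "grad can be implicitly created only for scalar outputs";
# pass a seed gradient shaped like `out` instead.
out.backward(gradient=torch.ones_like(out))
print(x.grad)                    # tensor([2., 2., 2.])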
Invalid gradient shape after discarding filters during training
Hello, I'm trying to remove some filters during training, however, the .backward() in the second iteration raises an error due to the size ......
Read more >
tf.gradients | TensorFlow v2.11.0
When grad_ys is None, we fill in a tensor of '1's of the shape of y for each y in ys. A...
Read more >
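
Likewise, a hedged sketch of what that grad_ys default means in practice (tf.gradients has to run in a graph context, hence the tf.function wrapper):

import tensorflow as tf

@tf.function
def grads(x):
    y = x * x
    # grad_ys=None is filled in with ones shaped like y, so these two calls agree
    (g_default,) = tf.gradients(y, x)
    (g_explicit,) = tf.gradients(y, x, grad_ys=tf.ones_like(y))
    return g_default, g_explicit

print(grads(tf.constant([1.0, 2.0, 3.0])))  # both gradients: [2. 4. 6.]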
Automatic differentiation package - torch.autograd
Automatic differentiation package - torch.autograd. torch.autograd provides classes and functions implementing automatic differentiation of arbitrary scalar ...
Read more >
Homework 1 Part 1 | Deep Learning, CMU
Mistakes early on can lead to a cascade of problems later, ... In this case, grad fn is a BackwardFunction object.
Read more >
