TypeError: Gradient only defined for scalar-output functions. Output had shape: (1,)
See original GitHub issue.

I ran into the error in the title:
TypeError: Gradient only defined for scalar-output functions. Output had shape: (1,)
Should jax not treat arrays with shape (1,) as scalars? I also feel like this had worked before.
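For context, a minimal sketch of the distinction jax draws here (an output of shape () counts as a scalar, while shape (1,) does not):

from jax import numpy as jnp, grad

x = jnp.array([1.0, 2.0, 3.0])
print(grad(lambda x: jnp.sum(x ** 2))(x))  # output has shape () -> works, prints [2. 4. 6.]
try:
    grad(lambda x: jnp.sum(x ** 2, keepdims=True))(x)  # output has shape (1,)
except TypeError as e:
    print(e)  # Gradient only defined for scalar-output functions. ...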
When I tried to modify the function, I ran into problems that so far I don't understand and cannot reproduce:
- When I made the function return an actual scalar (by taking the [0] slice), I ran into an Intel MKL Error that I haven't managed to pinpoint, since the Python debugger is not helpful here.
- When I tried to modify the function by returning previous_version.item(), I got an AttributeError saying that 'ConcreteArray' object has no attribute 'item'.
I guess the simple example below does not run into any of this, since the computation tree is trivial and it can just return zeros without doing any evaluations on the gradient pass.
I am a bit lost as to how I can actually compute a gradient of my function now…
Minimal reproducer:
from jax import numpy as np, value_and_grad
print(value_and_grad(lambda x: np.array([0.]))(np.array([1, 2, 3.])))
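One workaround, as a sketch (here f is a hypothetical stand-in for a function whose output has shape (1,), not the reporter's actual function): reshape the output to shape () before differentiating.

from jax import numpy as jnp, value_and_grad

def f(x):
    return jnp.array([jnp.sum(x)])  # stand-in: returns a shape-(1,) array

x = jnp.array([1.0, 2.0, 3.0])
val, g = value_and_grad(lambda x: jnp.reshape(f(x), ()))(x)
print(val, g)  # 6.0 [1. 1. 1.]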
Issue Analytics
- State:
- Created 4 years ago
- Reactions: 6
- Comments: 10 (7 by maintainers)
Top Results From Across the Web

Gradient only defined for scalar-output functions. Output had ...
The grad function requires real or complex-valued inputs but you are using integers. This chunk will work: import jax def model(x): return ...
Read more >

JAX As Accelerated NumPy
Many common NumPy programs would run just as well in JAX if you substitute np for jnp. ... TypeError: Gradient only defined...
Read more >

An example with autodiff - Kyle Cranmer
Gives error: # Gradient only defined for scalar-output functions. Output had shape: (10000,).
Read more >

Grad only applies to real scalar-output functions
Hi, I have the following quantum circuit. I got the error “TypeError: Grad only applies to real scalar-output functions.”
Read more >

TF_JAX_Tutorials - Part 9 (Autodiff in JAX) | Kaggle
The grad function in JAX is used for computing the gradients. ... shape: (5,) TypeError: Gradient only defined for scalar-output functions.
Read more >
FYI, to be totally honest, I actually need the first and second derivative of my function f. I can use your trick to reshape to () for the first derivative, but I couldn't make that work with the second derivative, since grad outputs a shape of (nparticles, ndim) and there isn't a way to map that to a scalar.
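A possible way around the second-derivative problem, as a sketch (f below is a hypothetical scalar-valued stand-in, not the commenter's actual function): once the output is a true scalar, jax can build the full second derivative as a Hessian directly, so the gradient never needs to be mapped back to a scalar.

import jax
from jax import numpy as jnp

def f(x):
    return jnp.sum(x ** 3)  # hypothetical stand-in with scalar output

x = jnp.ones((4, 3))            # nparticles=4, ndim=3
print(jax.grad(f)(x).shape)     # (4, 3)
print(jax.hessian(f)(x).shape)  # (4, 3, 4, 3)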
Thanks for the idea of using np.reshape. Unfortunately, it raises the same Intel MKL Error. I will try to get to the bottom of that one after I've gotten some sleep.

If the expectations of a non-sage are of interest: I expect option 1, i.e. that is how I expect x.item() to behave. My reasoning is that .item() grammatically (item ≠ items) suggests that the array only contains a single scalar value, so I will only use it if I know (or err in assuming) that it does have a single value, so I would want an error if that assumption is wrong.
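For reference, plain NumPy's .item() already behaves this way, which matches that expectation:

import numpy as onp  # plain NumPy, not jax.numpy

print(onp.array([3.0]).item())  # 3.0 - a size-1 array collapses to a Python scalar
try:
    onp.array([1.0, 2.0]).item()
except ValueError as e:
    print(e)  # can only convert an array of size 1 to a Python scalar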