nan when using np.where and grad
Here is the situation:
```python
import jax.numpy as np  # the issue uses np as the alias for jax.numpy
from jax import grad

f = lambda x: np.where(x < 1., x, np.sqrt(x))
df = grad(f)
df(1e-9)  # evaluates to 1, as expected
df(0.0)   # evaluates to nan, presumably because the second case of the where is nan there,
          # but it is the first case that should give the result, not the second!
```
Not sure if this is expected behavior, but it is surprising. I’m going to look for a workaround, but thought I should document it. Thanks!
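This is a minimal sketch of what is going on. It assumes a recent JAX with `jax.numpy` imported as `jnp`.

When `grad` differentiates through `jnp.where`, the branch that is not selected still receives a cotangent, just a zero one. The chain rule then multiplies that zero by the derivative of `sqrt` at 0, which is `inf`, and `0 * inf` is `nan`.

The hypothetical `f_safe` below applies the "double where" workaround described in the JAX FAQ. Wherever the `sqrt` result is going to be discarded, it feeds `sqrt` a harmless input instead.

```python
import jax
import jax.numpy as jnp

# Reverse mode sends a zero cotangent into the branch jnp.where discards,
# and the chain rule multiplies it by d(sqrt)/dx, which is inf at 0.
print(jax.grad(jnp.sqrt)(0.0))  # inf
print(0.0 * float("inf"))       # nan

def f_naive(x):
    return jnp.where(x < 1.0, x, jnp.sqrt(x))

def f_safe(x):
    # "Double where": wherever the sqrt result is discarded (x < 1),
    # give sqrt a harmless input (1.0) so its derivative stays finite.
    safe_x = jnp.where(x < 1.0, 1.0, x)
    return jnp.where(x < 1.0, x, jnp.sqrt(safe_x))

print(jax.grad(f_naive)(0.0))  # nan
print(jax.grad(f_safe)(0.0))   # 1.0
```

Both versions compute the same forward values. Only the gradient at points where the discarded branch is singular differs.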
Issue Analytics
- Created 3 years ago
- Comments: 5 (2 by maintainers)
Top GitHub Comments
Interesting… I see the dilemma now.
I think for my usage I have to conclude that the wrong derivative of sqrt(x)^2 at x=0 is probably the lesser of two evils, especially because my actual use case is to construct a ‘smooth’ sqrt(x dot x) which is quadratic around x=0, so I’ll never actually evaluate it there. Thanks a lot for the help and the links.
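A 'smooth' sqrt(x dot x) of the kind described here could look roughly like the sketch below. The name `smooth_norm`, the threshold `eps`, and the particular quadratic are assumptions for illustration, not the commenter's code.

With r = |x|, the quadratic r^2 / (2 eps) + eps / 2 matches sqrt(x dot x) in both value and slope at r = eps. A second `jnp.where` keeps the input to `sqrt` away from 0 on the branch that is discarded, so the gradient at x = 0 comes out finite rather than `nan`.

```python
import jax
import jax.numpy as jnp

def smooth_norm(x, eps=1e-3):
    # Hypothetical helper: a C^1 "smooth" version of sqrt(x . x).
    # For |x| < eps it uses |x|^2 / (2 eps) + eps / 2, which matches
    # sqrt(x . x) in value (eps) and slope (1) at |x| = eps.
    sq = jnp.dot(x, x)
    small = sq < eps**2
    # Keep sqrt's input away from 0 where its result is discarded, so the
    # zero cotangent is never multiplied by an infinite derivative.
    safe_sq = jnp.where(small, 1.0, sq)
    quad = sq / (2.0 * eps) + eps / 2.0
    return jnp.where(small, quad, jnp.sqrt(safe_sq))

print(jax.grad(smooth_norm)(jnp.zeros(3)))           # all zeros, no nan
print(jax.grad(smooth_norm)(jnp.array([3.0, 4.0])))  # approximately [0.6, 0.8]
```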
There is something of a gotcha, but it’s not obvious! It’s related to the one mentioned in that comment on #1052 that Jake linked.
I think that should be nan because the function isn’t differentiable over the reals at 0. It is differentiable over the nonnegative reals, but there the derivative is 1.0, not 0.0 as computed. (Related: enforcing differentiation conventions using custom_jvp.)
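The custom_jvp approach mentioned above can be sketched as follows. The name `sqrt_squared` is hypothetical.

The function sqrt(x)^2 from this thread gets an explicit derivative rule. At x = 0 that rule uses the derivative over the nonnegative reals, 1.0, instead of the `nan` that plain autodiff produces there. Which convention to enforce is a modeling choice; this one simply follows the observation in the comment.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def sqrt_squared(x):
    # sqrt(x)**2, which equals x on its domain x >= 0.
    return jnp.sqrt(x) ** 2

@sqrt_squared.defjvp
def sqrt_squared_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Chosen convention: the derivative over the nonnegative reals, which is
    # 1 everywhere on the domain, including the boundary point x = 0.
    return sqrt_squared(x), x_dot

naive = lambda x: jnp.sqrt(x) ** 2
print(jax.grad(naive)(0.0))         # nan: 2*sqrt(0) = 0 times d(sqrt)/dx = inf
print(jax.grad(sqrt_squared)(0.0))  # 1.0: the enforced convention
```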