
nan when using np.where and grad


Here is the situation:

import jax.numpy as np
from jax import grad

f = lambda x: np.where(x < 1., x, np.sqrt(x))
df = grad(f)
df(1e-9)  # evaluates to 1, as expected
df(0.0)   # evaluates to nan, presumably because the sqrt branch of the where has an
          # infinite derivative at 0, but it is the first branch that should give the result!

Not sure if this is expected behavior, but it is surprising. I'm going to look for a workaround, but thought I should document it. Thanks!
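A minimal sketch of why this happens, plus one common workaround, assuming a recent JAX install (it uses the `jnp` alias for `jax.numpy` rather than the issue's `np`). In reverse mode, `where` sends a zero cotangent into the unselected branch, but that branch is still differentiated. At x = 0 the derivative of sqrt, 0.5/sqrt(x), is infinite, so the branch contributes 0 * inf = nan, and that nan is added to the gradient of the selected branch. The workaround, often called the "double where" trick, feeds the sqrt branch a harmless input wherever that branch is not selected. The name `f_safe` and the placeholder value 1.0 are illustrative choices and do not come from the issue.

```python
import jax
import jax.numpy as jnp

# Function from the issue: the sqrt branch is evaluated (and differentiated)
# even where the where does not select it.
f = lambda x: jnp.where(x < 1., x, jnp.sqrt(x))

# Hypothetical fix: inside the sqrt branch, replace x with a safe value (1.0)
# wherever that branch is not selected, so its derivative stays finite.
def f_safe(x):
    x_for_sqrt = jnp.where(x < 1., 1., x)
    return jnp.where(x < 1., x, jnp.sqrt(x_for_sqrt))

print(jax.grad(jnp.sqrt)(0.0))  # inf: the derivative of sqrt at 0
print(jax.grad(f)(0.0))         # nan: 0 * inf leaks into the result
print(jax.grad(f_safe)(0.0))    # 1.0
print(jax.grad(f_safe)(4.0))    # 0.25, i.e. 0.5 / sqrt(4)
```

The inner `where` is what matters. The outer one alone cannot help, because its backward pass still multiplies by sqrt's derivative evaluated at the original input.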

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

mrtupek commented, Feb 22, 2021 (1 reaction)

Interesting… I see the dilemma now.

I think for my usage, I have to conclude that the wrong derivative of sqrt(x)^2 at x=0 is probably the lesser of two evils. Especially because my actual use case is to construct a ‘smooth’ sqrt(x dot x) which is quadratic around x=0, so I’ll never actually evaluate it there. Thanks a lot for the help and the links.
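The use case described above, a "smooth" sqrt(x dot x) that is quadratic around x = 0, can be sketched as follows. This is a hypothetical construction, not mrtupek's actual code. The name `smooth_norm`, the radius `eps`, and the quadratic r^2 / (2 eps) + eps / 2 are all assumptions, chosen so that the value and slope match the true norm at r = eps. The same double-where guard keeps `jnp.sqrt` away from 0, so the gradient at the origin comes out finite (zero) rather than nan.

```python
import jax
import jax.numpy as jnp

def smooth_norm(x, eps=1e-3):
    """Hypothetical smooth sqrt(x . x): quadratic inside |x| < eps."""
    r2 = jnp.dot(x, x)
    inside = r2 < eps ** 2
    # Give sqrt a safe input wherever the quadratic branch is used, so its
    # infinite derivative at 0 never enters the backward pass.
    safe_r2 = jnp.where(inside, eps ** 2, r2)
    outer = jnp.sqrt(safe_r2)
    # Quadratic branch: equals eps and has slope 1 at r = eps, like sqrt(r2).
    inner = r2 / (2 * eps) + eps / 2
    return jnp.where(inside, inner, outer)

g0 = jax.grad(smooth_norm)(jnp.zeros(3))
g1 = jax.grad(smooth_norm)(jnp.array([3.0, 4.0, 0.0]))
print(g0)  # [0. 0. 0.]
print(g1)  # approximately [0.6 0.8 0. ], i.e. x / |x|
```

As mrtupek notes, the function is never evaluated exactly on the sqrt branch at 0, so the questionable derivative of sqrt there never comes into play.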

mattjj commented, Feb 22, 2021 (0 reactions)

There is something of a gotcha, but it’s not obvious! It’s related to the one mentioned in that comment on #1052 that Jake linked.

print(grad(lambda x: safe_sqrt(x)**2)(0.))  # prints 0.0

I think that should be nan because the function isn’t differentiable over the reals at 0. It is differentiable over the nonnegative reals, but there the derivative is 1.0, not 0.0 as computed. (Related: enforcing differentiation conventions using custom_jvp.)
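The definition of `safe_sqrt` does not appear in this excerpt. The sketch below assumes it is a double-where guarded sqrt; that assumption may not match the code in the thread. Under it, the sketch reproduces the behavior described: the gradient of `safe_sqrt(x)**2` is 0.0 at 0, but 1.0 at any positive x, the derivative of x on the nonnegative reals.

```python
import jax
import jax.numpy as jnp

def safe_sqrt(x):
    # Hypothetical definition: guard the sqrt input so its derivative is
    # never evaluated at 0, and return the constant 0 there instead.
    safe_x = jnp.where(x > 0., x, 1.)
    return jnp.where(x > 0., jnp.sqrt(safe_x), 0.)

h = lambda x: safe_sqrt(x) ** 2  # equals x for x >= 0

print(jax.grad(h)(0.))    # 0.0: the constant branch contributes no gradient
print(jax.grad(h)(0.25))  # 1.0, matching d/dx x = 1
```

The 0.0 arises because the chain rule gives 2 * safe_sqrt(0) * (derivative of safe_sqrt at 0) = 2 * 0 * 0. Only an infinite derivative of sqrt at 0 could change that product, and then it would be 0 * inf = nan. This is why mattjj argues nan is the more honest answer.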


Top Results From Across the Web

How to avoid NaN when using np.where function in python?
You can use fillna(False). You are using Boolean indexing so always the values corresponding to NaN will be 0
jax.numpy.where - JAX documentation
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is...
numpy.gradient — NumPy v1.24 Manual
The gradient is computed using second order accurate central differences ... Gradient is calculated using N-th order accurate differences at the boundaries.
How do I obtain the index list in a NumPy Array of all the NaN ...
You will have to make use of np.isnan along with np.argwhere to achieve what you desire in this question. The following code will...
NaN is not recognized in pandas after np.where clause. Why ...
