Gradients through np.where when one of the branches is nan (#347, except with arctan2)
I have come across an issue that I think is almost identical to #347, except with arctan2. Here's a repro:
```python
import jax.numpy as np
from jax import grad

def test(a):
    b = np.arctan2(a, a)
    print(b)
    temp = np.array([0., 1., 0.])
    c = np.where(temp > 0.5, 0., b)
    print(c)
    return np.sum(c)

aa = np.array([3., 0., -5.])
print(test(aa))
print(grad(test)(aa))  # the masked middle entry comes out nan instead of 0.
```
Thanks for raising this, and for the clear repro.
It’s actually a surprisingly thorny issue, as we’ve recently realized, and I think the change I made in #383 to fix #347 was misguided.
The fundamental trouble, as @dougalm explained to me, is including `nan` (and `inf` too, but let's focus on `nan`) in a system that relies on properties of linear maps. For example, the field / vector space axioms require that `0 * x = x * 0 = 0` for any `x`. But `nan` also behaves like `nan * x = x * nan = nan` for any `x`. So what value should we assign to an expression like `0 * nan`? Should it be `0 * nan = nan` or `0 * nan = 0`? Unfortunately either choice leads to problems…
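For reference, IEEE floating-point arithmetic itself picks the first convention: a `nan` contaminates any product, even a product with zero.

```python
x = float("nan")
print(0.0 * x, x * 0.0)  # nan nan -- multiplying by zero does not rescue a nan
```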
This comes up in automatic differentiation because of how `nan` Jacobians can arise, as in your example, and interact with zero (co)vectors. The Jacobian of `lambda x: np.arctan2(x, x)` evaluated at 0 is `lambda x: np.nan * x`. That means if we choose the convention that `0 * nan = nan`, then we'll get `grad(lambda x: np.arctan2(x, x))(0.)` resulting in `nan`. That seems like a sensible outcome, since the mathematical function clearly denoted by that program isn't differentiable at 0.

So far so good with the `0 * nan = nan` convention. How about a program like this one?
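For instance, a sketch of such a program, where ordinary Python control flow resolves the branch before anything is traced (the exact function here is illustrative):

```python
import jax.numpy as np
from jax import grad

def f(x):
    if x == 0.:                # decided in Python, so only one branch is ever traced
        return 0.
    else:
        return np.arctan2(x, x)

print(grad(f)(0.))             # 0.0 -- the arctan2 branch never enters the trace
```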
That also works like you might expect, with `grad(f)(0.) == 0.`. So what goes wrong in your example? Or in this next one, which seems like it should mean the same thing as the one just above?
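Perhaps something like this `np.where` version, which reads as if it should mean exactly the same thing (again an illustrative sketch):

```python
import jax.numpy as np
from jax import grad

def f(x):
    return np.where(x == 0., 0., np.arctan2(x, x))

print(grad(f)(0.))   # nan, even though the arctan2 value is never selected
```

Or even a variant in the same spirit where the bad branch can never be selected at all, because the predicate is a constant:

```python
def g(x):
    return np.where(True, 0., np.arctan2(x, x))

print(grad(g)(0.))   # still nan: the deselected branch is still in the backward pass
```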
And before you start thinking (like I did) that this is a problem only with `np.where` specifically, here's another instantiation of the same problem without using `np.where`:
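For instance, the indexing expression referenced again further down, which only ever reads the constant element:

```python
import jax.numpy as np
from jax import grad

print(grad(lambda x: np.array([0., 1. / x])[0])(0.))   # nan
```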
That last one is just a funny denotation of the zero function (i.e. a constant function), and we're still getting `nan`!

The trouble with all these, both with `np.where` and the indexing example, is that in some path through the program (e.g. one side of the `np.where`) we're generating Jacobians like `lambda x: x * np.nan`. These paths aren't "taken" in that they're selected off by some array-level primitive like `np.where` or indexing, and ultimately that means zero covectors are propagated along them. If no `nan`s were involved that'd work great, because those branches end up contributing zeros to the final sum-total result. But with the convention that `0 * nan = nan`, those zero covectors can be turned into `nan`s, and including those `nan` covectors in a sum gives us a `nan` result.
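To make the mechanism concrete, the partial derivatives of `np.arctan2` at `(0., 0.)` are `0/0 = nan`, and those are the partials that the deselected `np.where` branch multiplies by a zero covector on the backward pass:

```python
import jax
import jax.numpy as np

# Both partials of arctan2 at (0., 0.) are 0/0 = nan ...
print(jax.grad(np.arctan2, argnums=(0, 1))(0., 0.))   # (nan, nan)

# ... and the deselected np.where branch multiplies them by a zero covector,
# so the contribution it adds into the final sum is 0 * nan = nan.
```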
From that perspective, it's almost surprising that the example with the explicit `if` didn't have any issue. But that's because in that case the program we differentiate doesn't represent both paths through the program: we specialize away the `if` entirely and only see the "good" path. The trouble with these other examples is we're writing programs that explicitly represent both sides when specialized out for autodiff, and even though we're only selecting a "good" side, as we propagate zero covectors through the bad side on the backward pass we start to generate `nan`s.

Okay, so if we choose `0 * nan = nan` we'll end up getting `nan` values for gradients of programs that we think denote differentiable mathematical functions. What if we just choose `0 * nan = 0`? That's what #383 did (as an attempted fix for #347), introducing `lax._safe_mul`, for which `0 * x = 0` for any `x` value, and using it in some differentiation rules. But it turns out the consequences of this choice are even worse, as @alextp showed me with this example:
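A function with the behaviour described here is, for example, `np.sqrt(x) ** 2` (illustrative, not necessarily the exact snippet): it equals `x` for `x >= 0`, is undefined to the left of 0, and its chain rule at 0 forms a `0 * inf` product:

```python
import jax.numpy as np
from jax import grad

f = lambda x: np.sqrt(x) ** 2   # equals x for x >= 0, undefined for x < 0

print(grad(f)(0.))
# The backward pass forms 0 * inf here: under IEEE semantics that is nan, while a
# blanket 0 * x = 0 rule turns it into 0., silently contradicting the one-sided
# derivative of 1.
```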
That should be a `nan`. If it weren't a `nan` then the only defensible value is `1`, since that's the directional derivative on the right (and there isn't one on the left). But we're giving 0, and that's a silently incorrect derivative, the worst sin a system can commit. And that incorrect behavior comes from choosing `0 * nan = 0` (i.e. from `lax._safe_mul` and #383).

So I plan to revert #383 and go back to producing `nan` values as the lesser of two evils.

The only solution we've come up with so far to achieve both criteria (i.e. to produce non-`nan` derivatives for programs that involve selecting off non-differentiable branches, and not to produce incorrect zero derivatives for non-differentiable programs where we should instead get `nan`) is pretty heavyweight, and is something like tracking a symbolic zero mask potentially through the entire backward pass of differentiation. (These issues don't come up in forward-mode.) That solution sounds heavy both in terms of implementation and in terms of runtime work to be done.

Once we revert #383, if you have programs (like the one in the OP) that you want to express in a differentiable way, one workaround might be to write things in terms of a `vectorized_cond` function instead of `np.where`, maybe something like this:
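A sketch of one way such a helper could look; the `safe_operand` filler argument is a made-up detail here, and whatever value is used has to be harmless for the branch functions (for `np.arctan2`, any nonzero value works):

```python
import jax.numpy as np
from jax import grad

def vectorized_cond(pred, true_fun, false_fun, operand, safe_operand=1.):
    # Feed each branch a doctored operand: wherever that branch is NOT selected,
    # substitute a harmless filler so the branch never produces nan partials.
    true_op = np.where(pred, operand, safe_operand)
    false_op = np.where(pred, safe_operand, operand)
    return np.where(pred, true_fun(true_op), false_fun(false_op))

# Applied to the repro from the OP:
def test2(a):
    temp = np.array([0., 1., 0.])
    c = vectorized_cond(temp > 0.5,
                        lambda t: np.zeros_like(t),    # branch taken where temp > 0.5
                        lambda t: np.arctan2(t, t),    # branch taken elsewhere
                        a)
    return np.sum(c)

print(grad(test2)(np.array([3., 0., -5.])))            # [0. 0. 0.], no nan
```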
But that's clumsy, and doesn't solve the indexing version of the problem (i.e. `grad(lambda x: np.array([0., 1./x])[0])(0.)`). Only the heavyweight solution would handle that, as far as we know. Maybe we could implement it and make it opt-in, like `--differentiate_more_programs_but_slowly`…

WDYT?
It might be worth adding the double-where trick to "JAX - The Sharp Bits". (I looked there first, before eventually finding an explanation for my error here.)
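Concretely, the double-`where` trick applied to the repro at the top of the thread looks something like this (the filler value `1.` is arbitrary; it just has to be a safe input for `np.arctan2`):

```python
import jax.numpy as np
from jax import grad

def test_safe(a):
    mask = np.array([0., 1., 0.]) > 0.5
    safe_a = np.where(mask, 1., a)         # inner where: hide the bad input ...
    b = np.arctan2(safe_a, safe_a)         # ... so arctan2 never sees (0., 0.)
    return np.sum(np.where(mask, 0., b))   # outer where: same values as before

print(grad(test_safe)(np.array([3., 0., -5.])))   # [0. 0. 0.], no nan
```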