
Gradients through np.where when one of the branches is nan: #347, except with arctan2


I have come across an issue that I think is almost identical to #347, except with arctan2. Here’s a repro:

import jax.numpy as np
from jax import grad

def test(a):
  # b is nan-free (arctan2(0., 0.) is defined to be 0.), but the
  # Jacobian of arctan2 at the origin is nan.
  b = np.arctan2(a, a)
  print(b)

  # Select the constant 0. exactly where a == 0., i.e. exactly where
  # b's Jacobian is nan.
  temp = np.array([0., 1., 0.])
  c = np.where(temp > 0.5, 0., b)
  print(c)

  return np.sum(c)

aa = np.array([3., 0., -5.])
print(test(aa))        # -1.5707964: the forward pass is nan-free
print(grad(test)(aa))  # [ 0. nan  0.]: nan, even though the nan branch was selected off

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 10 (5 by maintainers)

Top GitHub Comments

13 reactions
mattjj commented, Jul 23, 2019

Thanks for raising this, and for the clear repro.

It’s actually a surprisingly thorny issue, as we’ve recently realized, and I think the change I made in #383 to fix #347 was misguided.

The fundamental trouble, as @dougalm explained to me, is including nan (and inf too, but let’s focus on nan) in a system that relies on properties of linear maps. For example, the field / vector space axioms require that 0 * x = x * 0 = 0 for any x. But nan also behaves like nan * x = x * nan = nan for any x. So what value should we assign to an expression like 0 * nan? Should it be 0 * nan = nan or 0 * nan = 0? Unfortunately either choice leads to problems…
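For concreteness, IEEE-754 arithmetic (and hence plain NumPy at the value level) commits to the nan-wins choice, as a one-line check shows:

import numpy as onp  # plain NumPy, just to illustrate the IEEE-754 behavior

print(onp.float64(0.) * onp.nan)  # ==> nan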

This comes up in automatic differentiation because of how nan Jacobians can arise, as in your example, and interact with zero (co)vectors. The Jacobian of lambda x: np.arctan2(x, x) evaluated at 0 is lambda x: np.nan * x. That means if we choose the convention that 0 * nan = nan, then grad(lambda x: np.arctan2(x, x))(0.) will be nan. That seems like a sensible outcome, since the mathematical function clearly denoted by that program isn’t differentiable at 0.
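One way to see that nan Jacobian directly is to push a tangent through in forward mode; a minimal sketch using jax.jvp:

from jax import jvp
import jax.numpy as np

f = lambda x: np.arctan2(x, x)
# The output tangent is the Jacobian at 0. applied to the input
# tangent 1., i.e. np.nan * 1.
primal_out, tangent_out = jvp(f, (0.,), (1.,))
print(tangent_out)  # ==> nan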

So far so good with the 0 * nan = nan convention. How about a program like this one?

def f(x):
  if x > 0.5:
    return np.arctan2(x, x)
  else:
    return 0.

# grad(f)(0.) ==> 0.

That also works as you might expect, with grad(f)(0.) == 0.. So what goes wrong in your example? Or in this next one, which seems like it should mean the same thing as the one just above?

def f(x):
  return np.where(x > 0.5, np.arctan2(x, x), 0.)

# grad(f)(0.) ==> nan

Or even…

def f(x):
  return np.where(False, np.arctan2(x, x), 0.)

# grad(f)(0.) ==> nan

And before you start thinking (like I did) that this is a problem only with np.where specifically, here’s another instantiation of the same problem without using np.where:

grad(lambda x: np.array([0., 1./x])[0])(0.)  # ==> nan

That last one is just a funny denotation of the zero function (i.e. a constant function), and we’re still getting nan!

The trouble with all these, both with np.where and the indexing example, is that in some path through the program (e.g. one side of the np.where) we’re generating Jacobians like lambda x: x * np.nan. These paths aren’t “taken” in that they’re selected off by some array-level primitive like np.where or indexing, and ultimately that means zero covectors are propagated along them. If no nans were involved that’d work great because those branches end up contributing zeros to the final sum-total result. But with the convention that 0 * nan = nan those zero covectors can be turned into nans, and including those nan covectors in a sum gives us a nan result.
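To make that concrete, here’s a sketch with jax.vjp of the np.where(False, ...) example, watching the zero covector turn into nan on the backward pass:

from jax import vjp
import jax.numpy as np

f = lambda x: np.where(False, np.arctan2(x, x), 0.)
primal_out, f_vjp = vjp(f, 0.)
# The transpose of np.where sends a zero covector down the (unselected)
# arctan2 branch; multiplying it by that branch's nan Jacobian computes
# 0. * nan = nan, which then lands in the summed-up input covector.
print(f_vjp(1.))  # ==> (nan,)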

From that perspective, it’s almost surprising that the example with the explicit if didn’t have any issue. But that’s because in that case the program we differentiate doesn’t represent both paths through the program: we specialize away the if entirely and only see the “good” path. The trouble with these other examples is we’re writing programs that explicitly represent both sides when specialized out for autodiff, and even though we’re only selecting a “good” side, as we propagate zero covectors through the bad side on the backward pass we start to generate nans.

Okay, so if we choose 0 * nan = nan we’ll end up getting nan values for gradients of programs that we think denote differentiable mathematical functions. What if we just choose 0 * nan = 0? That’s what #383 did (as an attempted fix for #347), introducing lax._safe_mul for which 0 * x = 0 for any x value and using it in some differentiation rules. But it turns out the consequences of this choice are even worse, as @alextp showed me with this example:

grad(lambda x: np.sqrt(x)**2)(0.)  # 0. !!!

That should be a nan. If it weren’t a nan, the only defensible value is 1, since that’s the directional derivative on the right (and there isn’t one on the left). But we’re giving 0., and that’s a silently incorrect derivative, the worst sin a system can commit. And that incorrect behavior comes from choosing 0 * nan = 0 (i.e. from lax._safe_mul and #383).
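A quick finite-difference check of that right-hand directional derivative (nothing JAX-specific here):

import jax.numpy as np

f = lambda x: np.sqrt(x) ** 2
h = 1e-6
print((f(h) - f(0.)) / h)  # ==> ~1.0, so reporting 0. at x = 0. is silently wrong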

So I plan to revert #383 and go back to producing nan values as the lesser of two evils.

The only solution we’ve come up with so far to achieve both criteria (i.e. to produce non-nan derivatives for programs that involve selecting off non-differentiable branches, and not to produce incorrect zero derivatives for non-differentiable programs where we should instead get nan) is pretty heavyweight, and is something like tracking a symbolic zero mask potentially through the entire backward pass of differentiation. (These issues don’t come up in forward-mode.) That solution sounds heavy both in terms of implementation and in terms of runtime work to be done.
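As a sanity check of that forward-mode remark, the np.where program from above differentiates cleanly under jax.jvp, because forward mode selects tangents instead of summing covectors:

from jax import jvp
import jax.numpy as np

f = lambda x: np.where(x > 0.5, np.arctan2(x, x), 0.)
# The nan tangent from the arctan2 branch is discarded by the select,
# never summed into anything, so it can't contaminate the result.
_, tangent_out = jvp(f, (0.,), (1.,))
print(tangent_out)  # ==> 0.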

Once we revert #383, if you have programs (like the one in the OP) that you want to express in a differentiable way, one workaround might be to write things in terms of a vectorized_cond function instead of np.where, maybe something like this:

def vectorized_cond(pred, true_fun, false_fun, operand):
  # true_fun and false_fun must act elementwise (i.e. be vectorized).
  # Sanitize each branch's input first: wherever a branch is not
  # selected, feed it a harmless 0 instead of the real operand, so that
  # branch's Jacobian can't produce nans at the positions we discard.
  true_op = np.where(pred, operand, 0)
  false_op = np.where(pred, 0, operand)
  return np.where(pred, true_fun(true_op), false_fun(false_op))

# no nans, even after reverting #383
grad(lambda x: vectorized_cond(x > 0.5, lambda x: np.arctan2(x, x), lambda x: 0., x))(0.)

But that’s clumsy, and it doesn’t solve the indexing version of the problem (i.e. grad(lambda x: np.array([0., 1./x])[0])(0.)). Only the heavyweight solution would handle that, as far as we know. Maybe we could implement it and make it opt-in, like a --differentiate_more_programs_but_slowly flag.

WDYT?

1 reaction
Sohl-Dickstein commented, Oct 29, 2022

It might be worth adding the double-where trick to JAX - The Sharp Bits. (I looked there first, before eventually finding an explanation for my error here.)
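For reference, the double-where trick applied to this issue’s arctan2 example might look something like the sketch below (safe_arctan2_diag and the dummy value 1. are just for illustration):

import jax.numpy as np
from jax import grad

def safe_arctan2_diag(x):
  # Inner where: swap a safe dummy value in for the dangerous input
  # (x == 0.) before it reaches arctan2, so no nan Jacobian is created.
  safe_x = np.where(x > 0.5, x, 1.)
  # Outer where: select the real answer, discarding the dummy one.
  return np.where(x > 0.5, np.arctan2(safe_x, safe_x), 0.)

print(grad(safe_arctan2_diag)(0.))  # ==> 0., no nan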
