
Possible bug for jax.numpy.abs complex-input grad


Please:

  • Check for duplicate issues. (I tried my best to find a duplicate for this.)
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks:

```python
In [1]: import jax; import jax.numpy as jnp
In [2]: t1 = jnp.array([[1.0, 2.0], [0.0, -1.0]])
In [3]: jax.grad(lambda t: jnp.sum(jnp.abs(t)))(t1)
Out[3]:
DeviceArray([[ 1.,  1.],
             [ 1., -1.]], dtype=float32)

In [7]: t2 = t1 * (1 + 0j)  # convert to complex64
In [8]: jax.grad(lambda t: jnp.sum(jnp.abs(t)))(t2)
Out[8]:
DeviceArray([[ 1.+0.j,  1.+0.j],
             [ 0.+0.j, -1.+0.j]], dtype=complex64)
```

In the code sequence above, we can compare the grad of essentially the same input as a float tensor and as a complex tensor. The problem I have (and for which I could not find any references) is that when an entry is 0, the grad changes from 1 (for float input) to 0 (for complex input), even though it is the same abs function applied to the same entry.

From looking at the source code, this seems like a matter of using select to replace the result, at the positions that were originally zero, with 1 after the calculation. I'd be happy to submit a PR if this is indeed a bug. Otherwise, a reference explaining why this behavior is intended would be great!
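To make the select-style suggestion concrete, here is a user-level sketch using `jax.custom_jvp`. The wrapper name `abs_unit_grad` is made up for illustration, and this is not JAX's actual rule; it just defines the derivative at exactly-zero entries to be 1 for real and complex inputs alike, while agreeing with the usual rule everywhere else:

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch: an abs whose JVP substitutes a slope of 1
# wherever the input is exactly zero, for real and complex inputs alike.
@jax.custom_jvp
def abs_unit_grad(x):
    return jnp.abs(x)

@abs_unit_grad.defjvp
def _abs_unit_grad_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = jnp.abs(x)
    safe_y = jnp.where(y == 0, 1.0, y)       # avoid 0/0 in the unused branch
    tangent_out = jnp.where(
        y == 0,
        jnp.real(t),                          # implementation-defined: slope 1 at 0
        jnp.real(jnp.conj(x) * t) / safe_y,   # usual d|x| = Re(conj(x)·t)/|x|
    )
    return y, tangent_out
```

With this wrapper, `jax.grad` gives 1 at a zero entry whether the input dtype is float32 or complex64, and matches `jnp.abs` at every nonzero point.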

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments:7 (4 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, May 23, 2022

Thanks for the question!

As you’ve already discussed, the observation which started this issue comes down to this:

```python
import jax
import jax.numpy as jnp

print(jax.grad(jnp.abs)(0.))        # 1.0
print(jax.grad(jnp.abs)(0. + 0.j))  # 0j
```

Also as already discussed, the mathematical function being represented here (the complex absolute value, and its restriction to the reals) isn't differentiable at the point where we're evaluating the gradient. We could return a nan here to represent the non-differentiability, but instead we usually just provide some implementation-defined answer which seems expedient. Yet our implementation-defined answer differs between the real- and complex-input cases, not for deep reasons but as an artifact of how we happened to write the abs JVP rule.

I think there are four options:

  1. make grad(abs)(0.) and grad(abs)(0.+0.j) both be nan
  2. make grad(abs)(0.) == grad(abs)(0.+0.j) == 1. (currently what we do for real-valued inputs only)
  3. make grad(abs)(0.) == grad(abs)(0.+0.j) == 0. (currently what we do for complex-valued inputs only)
  4. do nothing, keeping grad(abs)(0.) == 1. and grad(abs)(0.+0.j) == 0.

I bet if we choose option 1 we’ll break existing code (which I’m guessing relies on not producing nans here). If we choose option 3 we’d be changing the behavior for real-valued inputs, which I suspect is much more common.

Both options 2 and 4 seem viable. Let’s try 2 and see if anyone complains!
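Until one of these options lands, option 2 can be approximated from outside autodiff by patching the gradient wherever the primal input is exactly zero. This is a sketch, not library behavior; `g_fixed` is a name made up here:

```python
import jax
import jax.numpy as jnp

# The complex64 input from the original reproduction.
t = jnp.array([[1.0, 2.0], [0.0, -1.0]]) * (1 + 0j)
g = jax.grad(lambda t: jnp.sum(jnp.abs(t)))(t)

# Post-process: wherever the input was exactly 0, overwrite the
# implementation-defined gradient with 1 (option 2 above).
g_fixed = jnp.where(t == 0, 1.0 + 0j, g)
```

This matches the real-input behavior at zero entries and leaves every nonzero entry's gradient untouched.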

1 reaction
polvalente commented, May 2, 2022

@YouJiacheng exactly that. The issue is that the same number differentiates differently even though only the datatype changed, not the value itself.

I just continue to be unsure whether this is a deliberate decision or just a side effect of how it was implemented, which is the main reason why I opened this issue.

Re the edit above: I get that the complex extension opens up more paths along which the derivative can be taken. If efficiency is the sole reason they differ, perhaps a comment in the JVP code might suffice.
