Possible bug in jax.numpy.abs complex-input grad
Please:
- Check for duplicate issues. (I tried my best to find a duplicate of this.)
- Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
```
In [1]: import jax; import jax.numpy as jnp

In [2]: t1 = jnp.array([[1.0, 2.0], [0.0, -1.0]])

In [3]: jax.grad(lambda t: jnp.sum(jnp.abs(t)))(t1)
Out[3]:
DeviceArray([[ 1.,  1.],
             [ 1., -1.]], dtype=float32)

In [7]: t2 = t1 * (1 + 0j)  # convert to complex64

In [8]: jax.grad(lambda t: jnp.sum(jnp.abs(t)))(t2)
Out[8]:
DeviceArray([[ 1.+0.j,  1.+0.j],
             [ 0.+0.j, -1.+0.j]], dtype=complex64)
```
In the code sequence above, we can compare the grad of essentially the same input as a float tensor and as a complex tensor. The problem I have (and for which I could not find any references) is that when an entry is 0, the grad changes from 1 (for the float input) to 0 (for the complex input), even though it is the same abs function applied to that entry.
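For reference, the discrepancy reduces to a scalar example. This is a minimal check I would expect to behave as commented, given the array results above; the exact output formatting may vary by JAX version:

```python
import jax
import jax.numpy as jnp

# Gradient of abs at zero: 1 for a real input, 0 for a complex input.
print(jax.grad(jnp.abs)(0.0))         # 1.0
print(jax.grad(jnp.abs)(0.0 + 0.0j))  # 0 (as a complex number)
```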
From looking at the source code, this seems like a matter of using `select` to replace the result at the positions that were originally zero with 1 after the calculation (a sketch of the kind of fix I mean follows below). I'd be happy to submit a PR if this is indeed a bug. Otherwise, a reference on why this behavior is intended would be great!
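To make the suggestion concrete, here is a hypothetical sketch of such a fix as a user-level wrapper built on `jax.custom_jvp`. This is not JAX's actual rule; the name `abs_grad_one_at_zero` and its JVP are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def abs_grad_one_at_zero(x):
  return jnp.abs(x)

@abs_grad_one_at_zero.defjvp
def _abs_jvp(primals, tangents):
  (x,), (t,) = primals, tangents
  y = jnp.abs(x)
  safe_y = jnp.where(y == 0, 1.0, y)             # dodge 0/0 at the origin
  out_t = jnp.real(jnp.conj(x) * t) / safe_y     # standard d|x| = Re(conj(x)*t)/|x|
  out_t = jnp.where(y == 0, jnp.real(t), out_t)  # select grad == 1 at the zeros
  return y, out_t

t2 = jnp.array([[1.0, 2.0], [0.0, -1.0]]) * (1 + 0j)
print(jax.grad(lambda t: jnp.sum(abs_grad_one_at_zero(t)))(t2))
# expected: [[ 1.+0.j  1.+0.j]
#            [ 1.+0.j -1.+0.j]]
```

With this wrapper the real and complex cases agree at 0, at the cost of one extra `where` per call.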
Thanks for the question!

As you've already discussed, the observation that started this issue comes down to the example above.
Also as already discussed, the mathematical function being represented here (complex absolute value and its restriction to the real numbers) isn't differentiable at the point where we're evaluating the gradient. So we could return a `nan` here to represent the non-differentiability. But instead we usually just provide some implementation-defined answer that seems expedient. Yet our implementation-defined answer differs in the real- and complex-input cases, not for deep reasons but just as an artifact of how we happened to write the `abs` JVP rule.

I think there are four options:

1. `grad(abs)(0.)` and `grad(abs)(0.+0.j)` both be `nan`
2. `grad(abs)(0.) == grad(abs)(0.+0.j) == 1.` (currently what we do for real-valued inputs only)
3. `grad(abs)(0.) == grad(abs)(0.+0.j) == 0.` (currently what we do for complex-valued inputs only)
4. `grad(abs)(0.) == 1.` and `grad(abs)(0.+0.j) == 0.` (the current behavior)

I bet if we choose option 1 we'll break existing code (which I'm guessing relies on not producing `nan`s here). If we choose option 3 we'd be changing the behavior for real-valued inputs, which I suspect is much more common. Both options 2 and 4 seem viable. Let's try 2 and see if anyone complains!
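To see the non-differentiability concretely: the one-sided directional derivative of |z| at 0 along a direction v is |v|, which is not linear in v, so no linear JVP can be right for every direction. The following probe with `jax.jvp` shows the outputs I'd expect from the behavior described in this thread:

```python
import jax
import jax.numpy as jnp

# JVPs of abs at zero along a unit tangent. The true one-sided
# derivative is |v| == 1 in every direction, so any linear answer
# at this point is a convention, not a fact.
_, t_real = jax.jvp(jnp.abs, (0.0,), (1.0,))
_, t_cplx = jax.jvp(jnp.abs, (0.0 + 0.0j,), (1.0 + 0.0j,))
print(t_real)  # 1.0 -- the real rule effectively picks sign(0) == +1
print(t_cplx)  # 0.0 -- the complex rule picks 0 at the origin
```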
@YouJiacheng exactly that. The issue is that the same number differentiates differently even though only the datatype changed, not the number itself.

I just continue to be unsure whether this is a deliberate decision or just a side effect of how it was implemented, which is the main reason I opened this issue.

Re the edit above: I get that the complex extension opens up more directions along which the derivative can be taken. If efficiency is the sole reason they differ, perhaps a comment in the JVP code might suffice; a paraphrase of what such a comment could explain is sketched below.
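For what it's worth, here is a paraphrase of what the complex branch of the `abs` JVP appears to do, rewritten in `jnp` terms as an illustration (the actual rule lives in JAX's lax internals and may differ in detail), with the kind of comment suggested above:

```python
import jax.numpy as jnp

def abs_jvp_complex_paraphrase(x, t):
  """Illustrative paraphrase of the complex-input abs JVP, not JAX's source."""
  y = jnp.abs(x)
  # Away from zero: d|x| = Re(conj(x) * t) / |x|.
  # At zero the denominator is patched from 0 to 1 purely to avoid a
  # nan from 0/0; since the numerator Re(conj(0) * t) is 0, the tangent
  # comes out as 0 -- an implementation-defined choice, not a deep one.
  safe_y = jnp.where(y == 0, 1.0, y)
  return jnp.real(jnp.conj(x) * t) / safe_y
```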