Eigh gradients can be wrong
In a lot of cases, eigh doesn't return the correct gradients. Below is a very simple example comparing JAX gradients (incorrect ones) with autograd gradients (correct ones).
JAX gradients computation:
from jax import grad
import jax.numpy as jnp
a = jnp.array([[1.93865817, 0.35509264, 0.64405863, 1.3430815 , 0.92342772],
[0.35509264, 1.30340895, 0.90821576, 0.7199631 , 1.49002679],
[0.64405863, 0.90821576, 0.50616539, 0.97171981, 0.9621489 ],
[1.3430815 , 0.7199631 , 0.97171981, 1.27540148, 1.44715998],
[0.92342772, 1.49002679, 0.9621489 , 1.44715998, 1.75984414]])
def fun(x):
    e, v = jnp.linalg.eigh(x)
    return jnp.sum(jnp.abs(v))

diff_fun = grad(fun)
print(diff_fun(a))
JAX Output: (incorrect)
[[-0.33112606  0.32743236 -0.18718082  0.643335   -0.30847853]
 [ 0.32743236  0.5745873  -0.07279176 -0.0815831  -0.50257957]
 [-0.18718082 -0.07279176 -0.4819119   0.29376295  0.19632888]
 [ 0.643335   -0.0815831   0.29376295 -0.4126568  -0.28805062]
 [-0.30847853 -0.50257957  0.19632888 -0.28805062  0.65110785]]
Autograd gradients computation script:
import autograd.numpy as npa
from autograd import elementwise_grad as grad
a = npa.array([[1.93865817, 0.35509264, 0.64405863, 1.3430815 , 0.92342772],
[0.35509264, 1.30340895, 0.90821576, 0.7199631 , 1.49002679],
[0.64405863, 0.90821576, 0.50616539, 0.97171981, 0.9621489 ],
[1.3430815 , 0.7199631 , 0.97171981, 1.27540148, 1.44715998],
[0.92342772, 1.49002679, 0.9621489 , 1.44715998, 1.75984414]])
def fun(x):
    e, v = npa.linalg.eigh(x)
    return npa.sum(npa.abs(v))

diff_fun = grad(fun)
print(diff_fun(a))
Autograd Output: (correct)
[[-0.33112716  0.52218473  0.7270792   0.67123525 -1.13012818]
 [ 0.13268013  0.57458894  1.47796036  0.09381337 -1.45197577]
 [-1.10143499 -1.62353805 -0.48191193  0.63382524  1.86241553]
 [ 0.61543807 -0.25697898 -0.04630763 -0.41265955  0.06441824]
 [ 0.51316611  0.44681027 -1.46975969 -0.64051274  0.65110971]]
It's worth noting that the mistake may be related exclusively to the eigenvectors. When using an objective function that uses only the eigenvalues, JAX returns the correct gradients, for example:
def fun(x):
    e, v = jnp.linalg.eigh(x)
    return jnp.sum(jnp.abs(e))
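One way to sanity-check this claim (a minimal sketch, assuming a symmetric test matrix with simple, non-zero eigenvalues; the matrix below is an arbitrary choice, not the one from the report) is to compare jax.grad against the closed-form gradient $\sum_i \mathrm{sign}(\lambda_i)\, v_i v_i^\top$ of $\sum_i |\lambda_i|$:

import jax
import jax.numpy as jnp

# Hypothetical test matrix; its eigenvalues are -3, 1 and 3 (distinct, non-zero,
# mixed signs, so the abs() actually matters).
a = jnp.array([[ 1.,  2.,  0.],
               [ 2., -1.,  2.],
               [ 0.,  2.,  1.]])

def fun(x):
    e, v = jnp.linalg.eigh(x)
    return jnp.sum(jnp.abs(e))

g_autodiff = jax.grad(fun)(a)

# Closed-form gradient of sum_i |lambda_i| for a symmetric argument:
# sum_i sign(lambda_i) v_i v_i^T = V diag(sign(lambda)) V^T.
e, v = jnp.linalg.eigh(a)
g_closed_form = (v * jnp.sign(e)) @ v.T

print(jnp.allclose(g_autodiff, g_closed_form, atol=1e-4))  # expected: True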
Comments: 5 (1 by maintainers)
I notice that the gradient output of Jax is just the symmetric part of that of autograd.
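This observation can be checked directly; below is a minimal sketch (assuming both jax and autograd are installed). The expected result follows from the two gradient matrices printed above, although the exact autograd output may depend on the autograd version.

import numpy as np
import jax.numpy as jnp
from jax import grad as jax_grad
import autograd.numpy as npa
from autograd import grad as autograd_grad

# Same symmetric matrix as in the report above.
a = np.array([[1.93865817, 0.35509264, 0.64405863, 1.3430815 , 0.92342772],
              [0.35509264, 1.30340895, 0.90821576, 0.7199631 , 1.49002679],
              [0.64405863, 0.90821576, 0.50616539, 0.97171981, 0.9621489 ],
              [1.3430815 , 0.7199631 , 0.97171981, 1.27540148, 1.44715998],
              [0.92342772, 1.49002679, 0.9621489 , 1.44715998, 1.75984414]])

def fun_jax(x):
    e, v = jnp.linalg.eigh(x)
    return jnp.sum(jnp.abs(v))

def fun_autograd(x):
    e, v = npa.linalg.eigh(x)
    return npa.sum(npa.abs(v))

g_jax = np.asarray(jax_grad(fun_jax)(jnp.asarray(a)))
g_autograd = autograd_grad(fun_autograd)(a)

# Symmetrizing the autograd gradient should reproduce the JAX gradient.
# (JAX runs in float32 by default, hence the loose tolerance.)
print(np.allclose(g_jax, 0.5 * (g_autograd + g_autograd.T), atol=1e-4))  # expected: True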
Since the original matrix a to be diagonalized is symmetric, this discrepancy between the Jax and autograd results is not really a bug. In practical applications, the symmetry of a is usually guaranteed by some upstream computation, and whether the adjoint of a itself is symmetric or not would not affect the gradients of real interest.

Notable observation: the V-matrix in the eigenvalue decomposition is not unique, because $V \Lambda V^\top = (V D)\, \Lambda\, (V D)^\top$ for any matrix $D = \mathrm{diag}(\pm 1)$.
Running your code and comparing the eigenvector matrix V returned by Jax with the one returned by autograd, the two agree up to column signs, i.e. $D = \mathrm{diag}([-1, +1, -1, +1, -1])$. Of course $\|V\|_{1,1} = \sum_{ij} |V_{ij}|$ is invariant under such sign flips, and the gradient should be identical either way.
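To make the non-uniqueness concrete, here is a minimal sketch (using the matrix from the report and the D quoted above) showing that flipping the signs of eigenvector columns changes neither the decomposition nor the objective $\sum_{ij} |V_{ij}|$:

import jax.numpy as jnp

a = jnp.array([[1.93865817, 0.35509264, 0.64405863, 1.3430815 , 0.92342772],
               [0.35509264, 1.30340895, 0.90821576, 0.7199631 , 1.49002679],
               [0.64405863, 0.90821576, 0.50616539, 0.97171981, 0.9621489 ],
               [1.3430815 , 0.7199631 , 0.97171981, 1.27540148, 1.44715998],
               [0.92342772, 1.49002679, 0.9621489 , 1.44715998, 1.75984414]])

e, v = jnp.linalg.eigh(a)
d = jnp.array([-1., 1., -1., 1., -1.])  # D = diag([-1, +1, -1, +1, -1])
v_flipped = v * d                       # flips the sign of columns 0, 2 and 4

# V D diagonalizes a just as well as V does ...
print(jnp.allclose(v_flipped @ jnp.diag(e) @ v_flipped.T, a, atol=1e-4))  # expected: True
# ... and the objective sum_ij |V_ij| is blind to the column signs.
print(jnp.abs(v).sum(), jnp.abs(v_flipped).sum())  # identical values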