Support autodiff of Eigendecomposition with repeated eigenvalues
On v0.1.25 on OSX, I get an error when computing gradients of the following jit-compiled function.
import numpy as onp
import jax.numpy as np
from jax import grad, jit
def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(val))
grad_test = jit(grad(test))
grad_test_jc = jit(grad(jit(test)))
x = onp.eye(3, dtype=onp.double)
xc = onp.eye(3, dtype=onp.complex128)
print(test(x))
print(grad_test(x))
print(grad_test_jc(x))
print(grad_test(xc))
3.0
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 1.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 1.+0.j]]
So far so good. But computing the gradient of the jit-compiled function with complex inputs raises an error:
print(grad_test_jc(xc))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-1-10b24cdf8a93> in <module>
19
20
---> 21 print(grad_test_jc(xc))
/usr/local/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
105 _check_args(jaxtupletree_args)
106 jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
--> 107 jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
108 return build_tree(out_tree(), jaxtupletree_out)
109
/usr/local/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **kwargs)
543 if top_trace is None:
544 with new_sublevel():
--> 545 ans = primitive.impl(f, *args, **kwargs)
546 else:
547 tracers = map(top_trace.full_raise, args)
/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args)
452 fun, out_tree = flatten_fun(fun, in_trees)
453
--> 454 compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
455 try:
456 flat_ans = compiled_fun(*flat_args)
/usr/local/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
206 if len(cache) > max_size:
207 cache.popitem(last=False)
--> 208 ans = call(f, *args)
209 cache[key] = (ans, f)
210 return ans
/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, *abstract_args)
473 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
474 assert not env # no subtraces here (though cond might eventually need them)
--> 475 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
476 del master, consts, jaxpr, env
477 handle_result = result_handler(result_shape)
/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in compile_jaxpr(jaxpr, const_vals, *abstract_args)
135 def compile_jaxpr(jaxpr, const_vals, *abstract_args):
136 arg_shapes = list(map(xla_shape, abstract_args))
--> 137 built_c = jaxpr_computation(jaxpr, const_vals, (), *arg_shapes)
138 result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
139 return built_c.Compile(arg_shapes, xb.get_compile_options(),
/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes)
173 map(c.GetShape, map(read, const_bindings + freevar_bindings)),
174 *in_shapes)
--> 175 for subjaxpr, const_bindings, freevar_bindings in eqn.bound_subjaxprs]
176 subfuns = [(subc, tuple(map(read, const_bindings + freevar_bindings)))
177 for subc, (_, const_bindings, freevar_bindings)
/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
173 map(c.GetShape, map(read, const_bindings + freevar_bindings)),
174 *in_shapes)
--> 175 for subjaxpr, const_bindings, freevar_bindings in eqn.bound_subjaxprs]
176 subfuns = [(subc, tuple(map(read, const_bindings + freevar_bindings)))
177 for subc, (_, const_bindings, freevar_bindings)
/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes)
167 for eqn in jaxpr.eqns:
168 in_nodes = map(read, eqn.invars)
--> 169 in_shapes = map(c.GetShape, in_nodes)
170 subcs = [
171 jaxpr_computation(
/usr/local/lib/python3.7/site-packages/jax/util.py in safe_map(f, *args)
41 for arg in args[1:]:
42 assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 43 return list(map(f, *args))
44
45
/usr/local/lib/python3.7/site-packages/jaxlib/xla_client.py in GetShape(self, operand)
876
877 def GetShape(self, operand):
--> 878 return _wrap_shape(self._builder.GetShape(operand))
879
880 def SetOpMetadata(self, op_metadata):
RuntimeError: Invalid argument: Binary op add with different element types: c64[3,3] and f32[1,3].
JAX built from source produced the same error.
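For reference, the identity gradients printed above are what one would expect mathematically: the sum of the eigenvalues of a Hermitian matrix equals its trace, which stays smooth even when eigenvalues repeat. Here is a minimal NumPy-only sketch (not JAX's implementation) that checks this with central finite differences. The helper names `sum_eigvals` and `fd_grad` are made up for illustration.

```python
import numpy as onp

def sum_eigvals(x):
    # eigh reads only one triangle of x (the lower one by default),
    # so symmetrize explicitly to make every entry contribute.
    sym = 0.5 * (x + x.T)
    val, _ = onp.linalg.eigh(sym)
    return float(onp.sum(val))

def fd_grad(f, x, eps=1e-6):
    # Central finite differences, perturbing one matrix entry at a time.
    g = onp.zeros_like(x)
    for idx in onp.ndindex(*x.shape):
        step = onp.zeros_like(x)
        step[idx] = eps
        g[idx] = (f(x + step) - f(x - step)) / (2 * eps)
    return g

x = onp.eye(3, dtype=onp.double)  # all three eigenvalues equal 1
g = fd_grad(sum_eigvals, x)
print(onp.round(g, 6))
```

Up to finite-difference error, `g` should match the 3x3 identity, which agrees with the real-valued outputs shown earlier.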
Issue Analytics
- State:
- Created 4 years ago
- Reactions:2
- Comments:30 (16 by maintainers)
It's probably worth noting that the example failure case for eigenvector derivatives from https://github.com/google/jax/issues/669#issuecomment-489303348 is not a well-defined matrix-valued function. E.g., suppose x = np.eye(2). Then the normalized eigenvectors vec could be either [[1, 0], [0, 1]] or [[1/sqrt(2), 1/sqrt(2)], [1/sqrt(2), -1/sqrt(2)]], so test(x) could be either 2 or 2/sqrt(2).

I came across this conversation and wanted to leave this note as a reference for near-degenerate eigenvectors, specifically the transformation of Eq. 10 into Eq. 11 to account for near-degeneracy: https://github.com/mitmath/18335/blob/master/notes/adjoint/eigenvalue-adjoint.pdf
Hopefully you find it useful, though some translation may be in order.
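The ambiguity described in the first note can be checked directly. The sketch below uses plain NumPy and assumes, hypothetically, that the function in the linked comment sums the entries of vec. That assumption is consistent with the values 2 and 2/sqrt(2) quoted above, but the linked code is not reproduced here. The last lines also compute the eigenvalue gaps that appear as 1/(lambda_j - lambda_i) factors in standard first-order perturbation theory for eigenvectors. This is general background, not a transcription of Eq. 10 or Eq. 11 from the linked notes.

```python
import numpy as onp

x = onp.eye(2)
s = 1.0 / onp.sqrt(2.0)

# Two equally valid sets of normalized eigenvectors (stored as columns)
# for the identity matrix.
vec_a = onp.array([[1.0, 0.0], [0.0, 1.0]])
vec_b = onp.array([[s, s], [s, -s]])

for vec in (vec_a, vec_b):
    # Columns are orthonormal and satisfy x @ v = 1 * v.
    assert onp.allclose(vec.T @ vec, onp.eye(2))
    assert onp.allclose(x @ vec, vec)

# Hypothetical stand-in for the linked test function: sum of eigenvector entries.
sum_a = onp.sum(vec_a)  # 2
sum_b = onp.sum(vec_b)  # 2/sqrt(2)

# Pairwise eigenvalue gaps lambda_j - lambda_i; they vanish when eigenvalues
# repeat, so any formula dividing by them needs special handling.
val = onp.linalg.eigvalsh(x)
gaps = val[None, :] - val[:, None]
print(sum_a, sum_b)
print(gaps)
```

Since both `vec_a` and `vec_b` are legitimate results for this input, any function of vec that is not invariant to the choice of basis within the degenerate subspace has no single value at x, let alone a derivative.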