
Support autodiff of Eigendecomposition with repeated eigenvalues


On v0.1.25 on OSX, I get the error below when computing gradients of the following jit-compiled function.

import numpy as onp
import jax.numpy as np
from jax import grad, jit

def test(x):
    # Sum of the eigenvalues of a Hermitian matrix (the sum is real, so np.real is a no-op here).
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(val))

grad_test = jit(grad(test))          # jit around grad(test)
grad_test_jc = jit(grad(jit(test)))  # jit around grad of an already-jitted test

x = onp.eye(3, dtype=onp.double)
xc = onp.eye(3, dtype=onp.complex)

print(test(x))
print(grad_test(x))
print(grad_test_jc(x))
print(grad_test(xc))

# Output:
3.0
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
[[1.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 1.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 1.+0.j]]

So far so good. But computing the gradient of the jit-compiled function with complex inputs raises an error:

print(grad_test_jc(xc))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-10b24cdf8a93> in <module>
     19 
     20 
---> 21 print(grad_test_jc(xc))

/usr/local/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    105     _check_args(jaxtupletree_args)
    106     jaxtree_fun, out_tree = pytree_fun_to_jaxtupletree_fun(f, in_trees)
--> 107     jaxtupletree_out = xla.xla_call(jaxtree_fun, *jaxtupletree_args)
    108     return build_tree(out_tree(), jaxtupletree_out)
    109 

/usr/local/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, f, *args, **kwargs)
    543   if top_trace is None:
    544     with new_sublevel():
--> 545       ans = primitive.impl(f, *args, **kwargs)
    546   else:
    547     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args)
    452   fun, out_tree = flatten_fun(fun, in_trees)
    453 
--> 454   compiled_fun = xla_callable(fun, *map(abstractify, flat_args))
    455   try:
    456     flat_ans = compiled_fun(*flat_args)

/usr/local/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in xla_callable(fun, *abstract_args)
    473     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master).call_wrapped(pvals)
    474     assert not env  # no subtraces here (though cond might eventually need them)
--> 475     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
    476     del master, consts, jaxpr, env
    477   handle_result = result_handler(result_shape)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in compile_jaxpr(jaxpr, const_vals, *abstract_args)
    135 def compile_jaxpr(jaxpr, const_vals, *abstract_args):
    136   arg_shapes = list(map(xla_shape, abstract_args))
--> 137   built_c = jaxpr_computation(jaxpr, const_vals, (), *arg_shapes)
    138   result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
    139   return built_c.Compile(arg_shapes, xb.get_compile_options(),

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes)
    173             map(c.GetShape, map(read, const_bindings + freevar_bindings)),
    174             *in_shapes)
--> 175         for subjaxpr, const_bindings, freevar_bindings in eqn.bound_subjaxprs]
    176     subfuns = [(subc, tuple(map(read, const_bindings + freevar_bindings)))
    177                for subc, (_, const_bindings, freevar_bindings)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in <listcomp>(.0)
    173             map(c.GetShape, map(read, const_bindings + freevar_bindings)),
    174             *in_shapes)
--> 175         for subjaxpr, const_bindings, freevar_bindings in eqn.bound_subjaxprs]
    176     subfuns = [(subc, tuple(map(read, const_bindings + freevar_bindings)))
    177                for subc, (_, const_bindings, freevar_bindings)

/usr/local/lib/python3.7/site-packages/jax/interpreters/xla.py in jaxpr_computation(jaxpr, const_vals, freevar_shapes, *arg_shapes)
    167   for eqn in jaxpr.eqns:
    168     in_nodes = map(read, eqn.invars)
--> 169     in_shapes = map(c.GetShape, in_nodes)
    170     subcs = [
    171         jaxpr_computation(

/usr/local/lib/python3.7/site-packages/jax/util.py in safe_map(f, *args)
     41   for arg in args[1:]:
     42     assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
---> 43   return list(map(f, *args))
     44 
     45 

/usr/local/lib/python3.7/site-packages/jaxlib/xla_client.py in GetShape(self, operand)
    876 
    877   def GetShape(self, operand):
--> 878     return _wrap_shape(self._builder.GetShape(operand))
    879 
    880   def SetOpMetadata(self, op_metadata):

RuntimeError: Invalid argument: Binary op add with different element types: c64[3,3] and f32[1,3].

JAX built from source produces the same error.
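Judging only from the outputs already shown above, the failure seems specific to the doubly-jitted variant on a complex input: the singly-jitted gradient went through fine on xc. So one hedged way to sidestep the error in this particular snippet (assuming that behaviour is representative) is simply to drop the inner jit:

# Hedged workaround, based only on the outputs shown above: jit the gradient,
# but do not additionally jit the inner function.
grad_test = jit(grad(test))
print(grad_test(xc))  # printed the 3x3 complex identity above, without a RuntimeError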

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Reactions: 2
  • Comments: 30 (16 by maintainers)

Top GitHub Comments

2 reactions
shoyer commented, Sep 29, 2020

It’s probably worth noting that the example failure case for eigenvector derivatives from https://github.com/google/jax/issues/669#issuecomment-489303348 is not a well-defined matrix-valued function:

def test(x):
    val, vec = np.linalg.eigh(x)
    return np.real(np.sum(vec))

E.g., suppose x = np.eye(2). Then the normalized eigenvectors vec could be either [[1, 0], [0, 1]] or [[1/sqrt(2), 1/sqrt(2)], [1/sqrt(2), -1/sqrt(2)]], so test(x) could be either 2 or 2/sqrt(2).
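
To make the ambiguity concrete, here is a small check (an illustrative sketch in plain NumPy, not taken from the issue) comparing two equally valid eigenbases of np.eye(2):

import numpy as onp

x = onp.eye(2)
# Two equally valid orthonormal eigenbases for the doubly degenerate eigenvalue 1:
v1 = onp.eye(2)
v2 = onp.array([[1.0, 1.0], [1.0, -1.0]]) / onp.sqrt(2.0)

# Both are genuine eigenvector matrices of x ...
print(onp.allclose(x @ v1, v1), onp.allclose(x @ v2, v2))  # True True
# ... yet summing their entries gives different values, so np.sum(vec) is not
# a well-defined function of x at x = np.eye(2):
print(v1.sum(), v2.sum())  # 2.0 vs. 2/sqrt(2) ≈ 1.414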

1 reaction
platawiec commented, Mar 19, 2021

I came across this conversation and wanted to leave this note as a reference for near-degenerate eigenvectors, specifically the transformation of Eq. 10 into Eq. 11 to account for near-degeneracy: https://github.com/mitmath/18335/blob/master/notes/adjoint/eigenvalue-adjoint.pdf . Hopefully you find it useful, though some translation may be in order.
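
To see what that Eq. 10 → Eq. 11 substitution looks like in code, here is a rough sketch (an illustration with a made-up helper name and an arbitrary eps; not the linked note's exact formulation nor JAX's actual implementation) of a reverse-mode rule for a real symmetric eigh whose 1/(w[j] - w[i]) factors are Lorentzian-broadened so that nearly degenerate pairs stay finite:

import jax.numpy as np

def eigh_vjp_broadened(a, w_bar, v_bar, eps=1e-12):
    # Sketch of the cotangent of a symmetric a = v @ diag(w) @ v.T, given the
    # cotangents w_bar (eigenvalues) and v_bar (eigenvectors).
    w, v = np.linalg.eigh(a)
    # The standard rule uses F[i, j] = 1 / (w[j] - w[i]) off the diagonal, which blows
    # up for (near-)degenerate eigenvalues. The broadened variant replaces it with
    # d / (d**2 + eps): finite everywhere, and close to 1/d once |d| >> sqrt(eps).
    d = w[None, :] - w[:, None]
    F = d / (d**2 + eps)            # also zero on the diagonal, since d == 0 there
    mid = np.diag(w_bar) + F * (v.T @ v_bar)
    a_bar = v @ mid @ v.T
    return 0.5 * (a_bar + a_bar.T)  # symmetrize, since a is assumed symmetric

How to choose eps, and whether to broaden only the nearly degenerate pairs, is exactly the kind of detail the linked note works through.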

