expm shape tracing issue through two vmaps and jvp
Hi,
I’ve been trying to narrow down an issue that arises when vmap is applied to a jvp that goes through a calculation which has its own vmap. I am new to JAX and may be missing something about shape tracing, but it seems to me that the problem is caused by the expm matrix operation.
Consider the following code, which implements a Jacobian calculation for either of two matrix product operations:
e^{I*(w[0,0]+w[0,1])} e^{I*(w[1,0]+w[1,1])}
or
(I * w[0,0]) * (I * w[0,1]) * (I * w[1,0]) * (I * w[1,1])
where I is the 2x2 identity matrix and w is a set of weights stored as a 2x2 array. The calculation that runs is toggled by commenting/uncommenting two lines in matrix_operation(). The code works for the second operation but breaks for the first one, throwing an assertion error about array shapes raised in lax_control_flow.py.
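Since I is the identity, both expressions collapse to a scalar multiple of I, which gives a quick sanity check on what the JVPs should be: the first is e^{w[0,0]+w[0,1]+w[1,0]+w[1,1]} I and the second is w[0,0] w[0,1] w[1,0] w[1,1] I. A minimal check, assuming la.expm refers to jax.scipy.linalg.expm and using made-up example weights:

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

I = jnp.identity(2)
# Example weights chosen only for illustration.
w = jnp.array([[0.1, 0.2], [0.3, 0.4]])

# Operation 1: product of two matrix exponentials of scaled identities.
op1 = expm(I * (w[0, 0] + w[0, 1])) @ expm(I * (w[1, 0] + w[1, 1]))
# expm(I * s) == exp(s) * I, so the product is exp(sum of all weights) * I.
closed1 = jnp.exp(jnp.sum(w)) * I

# Operation 2: product of four scaled identities.
op2 = (I * w[0, 0]) @ (I * w[0, 1]) @ (I * w[1, 0]) @ (I * w[1, 1])
# The identities multiply to I, leaving the product of the weights.
closed2 = jnp.prod(w) * I

print(jnp.allclose(op1, closed1), jnp.allclose(op2, closed2))
```

So for the first operation, the derivative with respect to every w[i,j] is the same matrix, e^{sum of w} I, which is what each one-hot tangent should produce.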
I am wondering whether this is a bug in expm or whether I am missing something. I would appreciate any help!
```python
import jax.numpy as jnp
from jax import vmap, jvp
from jax.scipy import linalg as la  # assumed: la is jax.scipy.linalg
from functools import partial
from jax.numpy.linalg import multi_dot

def matrix_operation(w):
    assert w.shape == (2,)
    A = jnp.identity(2)
    return la.expm(A*(w[0]+w[1]))  # ----- breaks down with this line
    #return jnp.matmul(A*w[0], A*w[1])  #----- works with this line!

def cost_function(ws):
    same_w = vmap(matrix_operation)(ws)  # one 2x2 matrix per row of ws
    product = multi_dot(same_w)
    return product

def pushfwd_(func, weights, tangent):
    return jvp(func, (weights,), (tangent,))

r = 2
c = 2
weights = jnp.ones((r, c))
pushfwd = partial(pushfwd_, cost_function, weights)
# a set of vectors with a single non-zero entry equal to 1
# and same shape as weights
tangents = jnp.reshape(jnp.identity(r*c), (c*r, r, c))

print(vmap(pushfwd)(tangents))  # this breaks with la.expm above
#pushfwd(tangents[0])  # this works with la.expm above
```
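A possible workaround, sketched here as an assumption rather than a confirmed fix, is to drop the outer vmap over tangents and use jax.lax.map instead, which evaluates pushfwd once per tangent inside a scan. Since pushfwd(tangents[0]) works on its own, this keeps only the inner vmap, which is the configuration that already works. A self-contained version, again assuming la is jax.scipy.linalg:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jvp, vmap
from jax.numpy.linalg import multi_dot
from jax.scipy import linalg as la


def matrix_operation(w):
    # w is one row of the weights, shape (2,)
    assert w.shape == (2,)
    A = jnp.identity(2)
    return la.expm(A * (w[0] + w[1]))


def cost_function(ws):
    # Inner vmap: one 2x2 matrix per row of ws, then chain them.
    return multi_dot(vmap(matrix_operation)(ws))


def pushfwd_(func, weights, tangent):
    return jvp(func, (weights,), (tangent,))


r, c = 2, 2
weights = jnp.ones((r, c))
pushfwd = partial(pushfwd_, cost_function, weights)
# One-hot tangents, each with the same shape as weights.
tangents = jnp.reshape(jnp.identity(r * c), (r * c, r, c))

# lax.map runs pushfwd sequentially over the leading axis (via scan),
# so the JVP itself is not vmapped over tangents; only the inner vmap
# inside cost_function remains.
primals, jvps = jax.lax.map(pushfwd, tangents)
print(jvps.shape)
```

At weights = ones, every slice of jvps should equal e^4 times the identity, because each one-hot tangent shifts the exponent sum by one. The trade-off is that lax.map loops instead of batching, so it gives up the parallelism vmap would provide.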
(Issue created 3 years ago; 7 comments, 4 by maintainers.)
What specific error message do you see?
The short answer for what’s wrong here is probably that JAX’s expm doesn’t support reverse-mode differentiation yet.
I think this is a bug in our while_loop batching rule (and a while_loop is used in expm, via the fori_loop wrapper)! It’s a bit hard to articulate, but I think I see it…
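One way to test the while_loop hypothesis independently of expm is to keep the same vmap-over-jvp-over-vmap nesting but replace expm with a function built directly on lax.while_loop. In the sketch below, repeated_square is a hypothetical stand-in (it squares the matrix three times), not part of JAX. If the batching rule is the culprit, this may hit the same assertion; on a JAX version without the bug it should return 2^18 * I for every tangent, since d/dw[0,0] of (w[0,0]+w[0,1])^8 (w[1,0]+w[1,1])^8 at w = ones is 8 * 2^7 * 2^8.

```python
import jax.numpy as jnp
from jax import jvp, lax, vmap


def repeated_square(m, n=3):
    # Hypothetical stand-in for expm: squares m n times (m -> m**(2**n))
    # using an explicit lax.while_loop.
    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, x = state
        return i + 1, x @ x

    _, out = lax.while_loop(cond, body, (0, m))
    return out


def matrix_operation(w):
    A = jnp.identity(2)
    return repeated_square(A * (w[0] + w[1]))


def cost_function(ws):
    mats = vmap(matrix_operation)(ws)  # inner vmap over the rows of ws
    return mats[0] @ mats[1]


weights = jnp.ones((2, 2))


def pushfwd(tangent):
    return jvp(cost_function, (weights,), (tangent,))


tangents = jnp.reshape(jnp.identity(4), (4, 2, 2))
# Outer vmap over the JVP: the same nesting as in the original report.
primals, jvps = vmap(pushfwd)(tangents)
print(jvps[0])
```

The sketch calls lax.while_loop directly rather than lax.fori_loop because fori_loop with Python-int bounds is lowered to a scan, which would not exercise the while_loop batching rule.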