expm shape tracing issue through two vmaps and jvp

Hi,

I’ve been trying to narrow down an issue that arises when a vmap is applied to a jvp that goes through a calculation which has its own vmap. I am new to JAX and may be lacking some understanding of shape tracing, but it seems to me that the problem is caused by the expm matrix operation.

Consider the following code, which computes the Jacobian of either the matrix product

e^{I*(w[0,0]+w[0,1])} e^{I*(w[1,0]+w[1,1])}

or

(I * w[0,0]) * (I * w[0,1]) * (I * w[1,0]) * (I * w[1,1])

where I is the 2x2 identity matrix and w is a set of weights stored as a 2x2 array. The calculation can be toggled by commenting/uncommenting the two return lines in matrix_operation(). The code works for the second operation but fails for the first, raising an assertion error about the shape detected in lax_control_flow.py.

Is this a bug in expm, or am I missing something? I would appreciate any help!

import jax.numpy as jnp
import jax.scipy.linalg as la  # provides la.expm used below
from jax import vmap, jvp
from functools import partial
from jax.numpy.linalg import multi_dot

def matrix_operation(w):
    assert w.shape == (2,)
    A = jnp.identity(2)
    return la.expm(A*(w[0]+w[1])) # -----  breaks down with this line
    #return jnp.matmul(A*w[0], A*w[1])  #----- works with this line! 
    
def cost_function(ws):
    same_w = vmap(matrix_operation)(ws)
    product = multi_dot(same_w)
    return product

def pushfwd_(func, weights, tangent): 
    return jvp(func, (weights,), (tangent,))

r = 2
c = 2
weights = jnp.ones((r, c))

pushfwd = partial(pushfwd_, cost_function, weights)

# a set of vectors with a single non-zero entry equal to 1
# and same shape as weights
tangents = jnp.reshape(jnp.identity(r*c), (c*r, r, c))

print(vmap(pushfwd)(tangents)) # this breaks with la.expm above
#pushfwd(tangents[0]) # this works with la.expm above
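
Since a single `pushfwd` call is reported to work, one possible workaround is to replace the outer `vmap` with a plain Python loop over the basis tangents. The sketch below is not from the original issue: `jacobian_by_loop` is a hypothetical helper, and `jnp.matmul` stands in for `multi_dot` because there are only two factors. For this particular function the result can be checked by hand: the identity matrix commutes with itself, so the product equals e^{w[0,0]+w[0,1]+w[1,0]+w[1,1]} * I, and every Jacobian slice should be e^4 * I when all weights are 1.

```python
import jax.numpy as jnp
import jax.scipy.linalg as la
from jax import jvp, vmap

def matrix_operation(w):
    # w holds the two weights of one row
    return la.expm(jnp.identity(2) * (w[0] + w[1]))

def cost_function(ws):
    mats = vmap(matrix_operation)(ws)    # shape (2, 2, 2)
    return jnp.matmul(mats[0], mats[1])  # two-factor product

weights = jnp.ones((2, 2))
tangents = jnp.reshape(jnp.identity(4), (4, 2, 2))

# Hypothetical helper: loop over the tangents in Python instead of
# applying an outer vmap, so only the inner vmap is traced.
def jacobian_by_loop(func, primal, tangent_batch):
    slices = [jvp(func, (primal,), (t,))[1] for t in tangent_batch]
    return jnp.stack(slices)

jac = jacobian_by_loop(cost_function, weights, tangents)
print(jac.shape)  # (4, 2, 2)
```

The loop costs one trace per tangent, so it is slower than a batched call, but it can serve as a reference to compare against once the `vmap` version runs.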

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments:7 (4 by maintainers)

Top GitHub Comments

1 reaction
shoyer commented, May 20, 2020

What specific error message do you see?

The short answer for what’s wrong here is probably that JAX’s expm doesn’t support reverse-mode differentiation yet.

0 reactions
mattjj commented, May 21, 2020

I think this is a bug in our while_loop batching rule (and a while_loop is used in expm, via the fori_loop wrapper)! It’s a bit hard to articulate, but I think I see it…
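
To make the code path in that comment concrete, here is a minimal sketch (not taken from the thread, and not a reproduction of the bug) of how a `fori_loop` with a non-static trip count turns into a `while` primitive, which `vmap` then has to batch. `repeated_square` is a hypothetical stand-in for the squaring step of a scaling-and-squaring matrix exponential; the actual internals of `expm` are not shown here.

```python
import jax
import jax.numpy as jnp
from jax import lax, vmap

def repeated_square(x, n):
    # Square x n times. When n is a traced value the trip count is
    # not known statically, so fori_loop falls back to while_loop.
    return lax.fori_loop(0, n, lambda i, acc: acc * acc, x)

# Both arguments are traced here, so the jaxpr contains `while`.
jaxpr_text = str(jax.make_jaxpr(repeated_square)(2.0, 3))
print("while" in jaxpr_text)

# Batching over per-example trip counts goes through the
# while_loop batching rule.
xs = jnp.array([2.0, 2.0, 2.0])
ns = jnp.array([0, 1, 2])
batched = vmap(repeated_square)(xs, ns)
print(batched)
```

With a batched trip count each example needs a different number of iterations, so the batching rule has to keep running the loop until every example is done while leaving finished examples unchanged.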

