
VJP differs from JVP, jacfwd, and jacrev output for high-order ODE derivatives


Hi all 👋

Some context

I am using jax to compute higher-order derivatives D^n y(t_0), n>1, of ODE solutions \dot y = f(y), y(t_0) = y_0, for some given f, y0, and t0. The precise recursion is explained in the example snippet below. See also #520.

Issue

I am noticing that the results of an implementation via vjp differ from the results obtained via jvp, jacfwd, and jacrev (and jet, but I am omitting jet in the comparison below). Have I used vjp incorrectly? I tried to remain as close as possible to https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#the-implementation-of-jacfwd-and-jacrev.

Reproducible example

The following snippet reproduces the anomaly.

import jax
import jax.numpy as jnp

def f(y):
    """Some arbitrary function R^2 -> R^2."""
    return jnp.stack([jnp.dot(y, y), jnp.sin(jnp.dot(y, y))])

#
# All of the functions below compute the recursion
#
# F_{n+1}(-) = (J F_n)(-) @ f(-)
#
# starting at F_0 = f.
#

def jacobian_vector_product(f, f0):

    def next_iterate(x):
        _, y = jax.jvp(f, primals=(x,), tangents=(f0(x),))
        return y

    return next_iterate

# does not work as expected
def vector_jacobian_product(f, f0):

    def next_iterate(x):
        _, dfx_fun = jax.vjp(f, x)
        return dfx_fun(f0(x))[0]

    return next_iterate

def jacobian_forward(f, f0):

    def next_iterate(x):
        return jax.jacfwd(f)(x) @ f0(x)

    return next_iterate

def jacobian_reverse(f, f0):

    def next_iterate(x):
        return jax.jacrev(f)(x) @ f0(x)

    return next_iterate

# Check that they do the same

y0 = jnp.arange(1.0, 3.0)

F, f0 = f, f
for _ in range(4):
    F = jacobian_vector_product(F, f0)
    print(F(y0))
# [6.164303  1.7485797]
# [71.161995 56.623775]
# [ 543.6876 1349.717 ]
# [ 9144.678 24030.408]

# does not work as expected
F, f0 = f, f
for _ in range(4):
    F = vector_jacobian_product(F, f0)
    print(F(y0))
# [ 9.455978 18.911957]
# [49.263916  -5.0995216]
# [214.84796 386.2104 ]
# [6731.6987 9059.133 ]

F, f0 = f, f
for _ in range(4):
    F = jacobian_forward(F, f0)
    print(F(y0))
# [6.164303  1.7485797]
# [71.162    56.623787]
# [ 543.6877 1349.7167]
# [ 9144.677 24030.408]

F, f0 = f, f
for _ in range(4):
    F = jacobian_reverse(F, f0)
    print(F(y0))
# [6.164303  1.7485797]
# [71.162    56.623783]
# [ 543.68774 1349.7172 ]
# [ 9144.677 24030.41 ]

I apologise for the rather lengthy example, but I figured it would be useful to compare all of these results side by side.
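As a quick programmatic check (not part of the original report), the agreement of the three consistent variants — jvp, jacfwd, and jacrev — can be verified with jnp.allclose instead of eyeballing the printouts:

```python
import jax
import jax.numpy as jnp

def f(y):
    """Same arbitrary function R^2 -> R^2 as in the issue."""
    return jnp.stack([jnp.dot(y, y), jnp.sin(jnp.dot(y, y))])

def step(F, f0, mode):
    """One step of the recursion F_{n+1}(x) = (J F_n)(x) @ f0(x)."""
    if mode == "jvp":
        return lambda x: jax.jvp(F, (x,), (f0(x),))[1]
    if mode == "jacfwd":
        return lambda x: jax.jacfwd(F)(x) @ f0(x)
    return lambda x: jax.jacrev(F)(x) @ f0(x)

y0 = jnp.arange(1.0, 3.0)

results = {}
for mode in ("jvp", "jacfwd", "jacrev"):
    F = f
    for _ in range(4):
        F = step(F, f, mode)
    results[mode] = F(y0)

# the three variants agree up to float32 round-off
assert jnp.allclose(results["jvp"], results["jacfwd"], rtol=1e-4)
assert jnp.allclose(results["jvp"], results["jacrev"], rtol=1e-4)
```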

Python: 3.8.10. Pip freeze:

absl-py==1.0.0
click==8.0.3
flatbuffers==2.0
jax==0.2.28
jaxlib==0.1.76
mccabe==0.6.1
mypy-extensions==0.4.3
numpy==1.22.2
opt-einsum==3.3.0
pathspec==0.9.0
platformdirs==2.5.0
pycodestyle==2.8.0
pyflakes==2.4.0
scipy==1.8.0
six==1.16.0
tomli==2.0.1
typing-extensions==4.0.1

Thanks a lot for your help! 😃

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

soraros commented, Feb 10, 2022 (1 reaction)

@pnkraemer In JAX, vjp is implemented in terms of jvp:

linearize :: (a -> b) -> a -> (b, T a -o T b)
linearize f =
  let ft = jvp f
   in partial_eval ft

vjp :: (a -> b) -> a -> (b, T b -o T a)
vjp f x =
  let (y, ft) = linearize f x
      ftt     = transpose ft
   in (y, ftt)
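The Haskell-style pseudocode above can be sketched directly in Python with jax.linearize and jax.linear_transpose (a rough illustration of the decomposition, not JAX's actual internals):

```python
import jax
import jax.numpy as jnp

def f(y):
    return jnp.sin(y) + jnp.sum(y)

y0 = jnp.arange(2.0)

# linearize = partial evaluation of jvp at y0: returns y and the
# linear map on tangents, T a -o T b
y, f_lin = jax.linearize(f, y0)

# transposing that linear map yields the pullback, T b -o T a
f_vjp = jax.linear_transpose(f_lin, y0)

# sanity check against jax.vjp
v = jnp.ones_like(y)
_, pullback = jax.vjp(f, y0)
assert jnp.allclose(f_vjp(v)[0], pullback(v)[0])
```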

What you do in the updated version of vector_jacobian_product is essentially using jvp, since

transpose (transpose ft) = ft = jvp f

Actually, if you take a slightly simpler f, vector_jacobian_product generates the same code as your jacobian_vector_product. You could try the following code to see for yourself.

import jax.numpy as jnp
import jaxlib
from jax import jit

opt = jaxlib.xla_extension.HloPrintOptions.short_parsable()

def f(y):
  return jnp.sin(y) + jnp.sum(y)

y0 = jnp.arange(2.)
F, f0 = f, f

fjvp = jit(jacobian_vector_product(F, f0))
fvjp = jit(vector_jacobian_product(F, f0))
print(fjvp.lower(y0).compile().compiler_ir()[0].to_string(opt))
print(fvjp.lower(y0).compile().compiler_ir()[0].to_string(opt))
soraros commented, Feb 10, 2022 (1 reaction)

I don’t see why J[f](x) @ f(x) would be equal to f(x) @ J[f](x) (the vjp case).
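To illustrate that point numerically (this check is not in the original thread): the vjp-based iterate applies the transposed Jacobian, i.e. it computes f0(x) @ J rather than J @ f0(x), and the two differ because the Jacobian of the issue's f is not symmetric:

```python
import jax
import jax.numpy as jnp

def f(y):
    # same f as in the issue; its Jacobian is NOT symmetric
    return jnp.stack([jnp.dot(y, y), jnp.sin(jnp.dot(y, y))])

y0 = jnp.arange(1.0, 3.0)
J = jax.jacfwd(f)(y0)
v = f(y0)

_, pullback = jax.vjp(f, y0)
vjp_out = pullback(v)[0]

# the pullback applies the transposed Jacobian: v @ J (= J.T @ v) ...
assert jnp.allclose(vjp_out, v @ J)
# ... which equals J @ v (the jvp case) only when J is symmetric
assert not jnp.allclose(vjp_out, J @ v)
```

This reproduces the first "anomalous" value in the issue, [9.455978, 18.911957], which is exactly f(y0) @ J rather than J @ f(y0).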


