VJP output differs from JVP, jacfwd, and jacrev for higher-order ODE derivatives
Hi all 👋
Some context

I am using `jax` to compute higher-order derivatives D^n y(t_0), n > 1, of ODE solutions \dot y = f(y), y(t_0) = y_0, for some given f, y_0, and t_0. The precise recursion is explained in the example snippet below. See also #520.
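Concretely, the chain-rule step behind this recursion: differentiating \dot y = f(y) with respect to t gives \ddot y = J[f](y) \dot y = J[f](y) f(y), and inductively D^{n+1} y(t_0) = F_n(y_0), where F_0 = f and F_{n+1}(x) = J[F_n](x) f(x).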
Issue
I am noticing that the results of an implementation via `vjp` differ from the results obtained via `jvp`, `jacfwd`, and `jacrev` (and `jet`, but I am omitting `jet` in the comparison below). Have I used `vjp` incorrectly? I tried to remain as close as possible to https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#the-implementation-of-jacfwd-and-jacrev.
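For reference, the cookbook construction I am imitating builds the reverse-mode Jacobian by pulling the standard basis of the output space back through `vjp`; roughly (my paraphrase, not the docs verbatim):

```python
import jax
import jax.numpy as jnp

def my_jacrev(f):
    def jacfun(x):
        y, vjp_fun = jax.vjp(f, x)
        # Pull each standard basis vector of the output space back
        # through the VJP; row i of the result is e_i^T J = J[i, :].
        (J,) = jax.vmap(vjp_fun)(jnp.eye(y.shape[0]))
        return J
    return jacfun
```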
Reproducible example
The following snippet reproduces the anomaly.
```python
import jax
import jax.numpy as jnp


def f(y):
    """Some arbitrary function R^2 -> R^2."""
    return jnp.stack([jnp.dot(y, y), jnp.sin(jnp.dot(y, y))])


#
# All of the functions below compute the recursion
#
#   F_{n+1}(-) = (J F_n)(-) @ f(-)
#
# starting at F_0 = f.
#


def jacobian_vector_product(f, f0):
    def next_iterate(x):
        _, y = jax.jvp(f, primals=(x,), tangents=(f0(x),))
        return y

    return next_iterate


# does not work as expected
def vector_jacobian_product(f, f0):
    def next_iterate(x):
        _, dfx_fun = jax.vjp(f, x)
        return dfx_fun(f0(x))[0]

    return next_iterate


def jacobian_forward(f, f0):
    def next_iterate(x):
        return jax.jacfwd(f)(x) @ f0(x)

    return next_iterate


def jacobian_reverse(f, f0):
    def next_iterate(x):
        return jax.jacrev(f)(x) @ f0(x)

    return next_iterate


# Check that they do the same
y0 = jnp.arange(1.0, 3.0)

F, f0 = f, f
for _ in range(4):
    F = jacobian_vector_product(F, f0)
    print(F(y0))
# [6.164303 1.7485797]
# [71.161995 56.623775]
# [ 543.6876 1349.717 ]
# [ 9144.678 24030.408]

# does not work as expected
F, f0 = f, f
for _ in range(4):
    F = vector_jacobian_product(F, f0)
    print(F(y0))
# [ 9.455978 18.911957]
# [49.263916 -5.0995216]
# [214.84796 386.2104 ]
# [6731.6987 9059.133 ]

F, f0 = f, f
for _ in range(4):
    F = jacobian_forward(F, f0)
    print(F(y0))
# [6.164303 1.7485797]
# [71.162 56.623787]
# [ 543.6877 1349.7167]
# [ 9144.677 24030.408]

F, f0 = f, f
for _ in range(4):
    F = jacobian_reverse(F, f0)
    print(F(y0))
# [6.164303 1.7485797]
# [71.162 56.623783]
# [ 543.68774 1349.7172 ]
# [ 9144.677 24030.41 ]
```
I apologise for the rather lengthy example, but I figured it would be useful to compare all of these results.
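To narrow things down, here is a single application of both products, compared against an explicitly materialised Jacobian (a minimal sketch of mine, reusing the same `f`):

```python
import jax
import jax.numpy as jnp

def f(y):
    return jnp.stack([jnp.dot(y, y), jnp.sin(jnp.dot(y, y))])

x = jnp.arange(1.0, 3.0)
v = f(x)
J = jax.jacfwd(f)(x)  # reference Jacobian

# jvp pushes the vector forward: J @ v
_, pushforward = jax.jvp(f, (x,), (v,))
print(jnp.allclose(pushforward, J @ v))  # True

# vjp pulls the vector back: v @ J, i.e. J.T @ v
_, vjp_fun = jax.vjp(f, x)
print(jnp.allclose(vjp_fun(v)[0], v @ J))  # True

# The two products agree only for symmetric Jacobians,
# which this J is not.
print(jnp.allclose(J @ v, v @ J))  # False
```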
Python: 3.8.10. `pip freeze`:

```
absl-py==1.0.0
click==8.0.3
flatbuffers==2.0
jax==0.2.28
jaxlib==0.1.76
mccabe==0.6.1
mypy-extensions==0.4.3
numpy==1.22.2
opt-einsum==3.3.0
pathspec==0.9.0
platformdirs==2.5.0
pycodestyle==2.8.0
pyflakes==2.4.0
scipy==1.8.0
six==1.16.0
tomli==2.0.1
typing-extensions==4.0.1
```
Thanks a lot for your help! 😃
@pnkraemer In JAX, `vjp` is implemented in terms of `jvp`[^1]. What you do in the updated version of `vector_jacobian_product` is essentially using `jvp`. Actually, if you take a slightly simpler `f`, `vector_jacobian_product` generates the same code as your `jacobian_vector_product`; you could try it to see for yourself.

I don't see why `J[f](x) @ f(x)` would be equal to `f(x) @ J[f](x)` (the `vjp` case).
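If one really wants the forward product J[f](x) @ v out of reverse-mode machinery, I believe the pullback can be transposed back into a pushforward with `jax.linear_transpose`, since the pullback is linear in its cotangent. An untested sketch, analogous to `vector_jacobian_product` above:

```python
import jax

def transposed_vjp_product(f, f0):
    def next_iterate(x):
        _, vjp_fun = jax.vjp(f, x)
        # vjp_fun maps v to (J[f](x).T @ v,); transposing this linear
        # map turns it back into u -> J[f](x) @ u, the pushforward.
        pushforward = jax.linear_transpose(vjp_fun, f(x))
        (y,) = pushforward((f0(x),))
        return y
    return next_iterate
```

If I am reading `jax.linear_transpose` correctly, iterating this should reproduce the `jacobian_vector_product` numbers above.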