Time-efficient higher-order forward mode?
Have you guys given any thought to how to efficiently compute higher-order derivatives of scalar-input functions? I have a use for 4th- and 5th-order derivatives of a scalar-input, vector-output function, namely, regularizing ODEs to be easy to solve.
I'm not sure, but I think that in principle the Nth-order derivative of a scalar-input function can be computed for only about N times the cost of evaluating the original function, if intermediate values are cached. We think we have a partial solution using https://github.com/JuliaDiff/TaylorSeries.jl, but I'd rather do it in JAX.
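To make the caching idea concrete, here is a minimal pure-Python sketch of the truncated-Taylor-arithmetic approach that TaylorSeries.jl takes. The names `Taylor`, `taylor_sin`, and `taylor_pow` are hypothetical, chosen just for illustration, not any library's API. Each primitive is evaluated once on a length-K coefficient vector, so the cost is polynomial in K rather than exponential in the nesting depth:

```python
import math

class Taylor:
    """Truncated Taylor series around x0: c[k] = f^(k)(x0) / k!.
    Carrying all K coefficients through each primitive caches the
    intermediate values that nested jvps would recompute."""
    def __init__(self, coeffs):
        self.c = list(coeffs)

    def __mul__(self, other):
        if not isinstance(other, Taylor):  # scalar * series
            return Taylor([other * a for a in self.c])
        K = len(self.c)  # both operands assumed to have K coefficients
        # Cauchy product, truncated to K coefficients: O(K^2) work.
        return Taylor([sum(self.c[j] * other.c[k - j] for j in range(k + 1))
                       for k in range(K)])
    __rmul__ = __mul__

def taylor_sin(u):
    # Standard recurrence: propagate sin and cos coefficients together,
    # since s' = u' * cos(u) and c' = -u' * sin(u).
    K = len(u.c)
    s = [math.sin(u.c[0])] + [0.0] * (K - 1)
    c = [math.cos(u.c[0])] + [0.0] * (K - 1)
    for k in range(1, K):
        s[k] = sum(j * u.c[j] * c[k - j] for j in range(1, k + 1)) / k
        c[k] = -sum(j * u.c[j] * s[k - j] for j in range(1, k + 1)) / k
    return Taylor(s)

def taylor_pow(u, n):
    out = Taylor([1.0] + [0.0] * (len(u.c) - 1))  # the constant series 1
    for _ in range(n):
        out = out * u
    return out

K = 11                                    # coefficients 0..10
t = Taylor([1.1, 1.0] + [0.0] * (K - 2))  # the input path x(s) = 1.1 + s
out = 0.3 * taylor_sin(t) * taylor_pow(t, 10)
print(out.c[10] * math.factorial(10))     # 10th derivative of f at 1.1
```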
The following toy example takes time exponential in the order of the derivative:
```python
from jax import jvp
import jax.numpy as np

def fwd_deriv(f):
    # Differentiate a scalar-input function by pushing the tangent 1.0
    # through a jvp; nest this once per derivative order.
    def df(x):
        return jvp(f, (x,), (1.0,))[1]
    return df

def f(t):
    return 0.3 * np.sin(t) * t**10

g = f
for i in range(10):
    g = fwd_deriv(g)

print(g(1.1))
```
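Continuing from the definitions above, one way to see the blowup is to count equations in the traced jaxpr as the nesting deepens; this is only a rough proxy for work, and the exact numbers will vary by JAX version:

```python
from jax import make_jaxpr

# Each extra jvp roughly doubles the traced program, since the tangent
# rule for every primitive re-traces parts of the primal computation.
h = f
for order in range(1, 6):
    h = fwd_deriv(h)
    print(order, len(make_jaxpr(h)(1.1).jaxpr.eqns))
```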
Is there a simple way to set things up with jvps and vjps to be more time-efficient, or do you think it would require a different type of autodiff entirely?
This was definitely a bit of a rabbit hole we stumbled into!
We discussed this a bit in our chat. We think the answer is to add a CSE pass that happens after every level of differentiation. But we might also need some other simplifications, like collecting terms (`x + x + x = 3x`). CSE is easy enough to add. We'll try it out and report back!
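For anyone landing here later: a Taylor-mode transform along these lines now exists in JAX as the experimental `jax.experimental.jet`. A minimal sketch, assuming a recent JAX version (treat the exact conventions as an assumption if your version differs):

```python
from jax.experimental import jet
import jax.numpy as jnp

def f(t):
    return 0.3 * jnp.sin(t) * t**10

order = 10
# Expand f along the input path x(s) = 1.1 + s: the first input term
# is 1.0 and all higher-order input terms are 0.0.
primals = (1.1,)
series = ([1.0] + [0.0] * (order - 1),)
primal_out, terms = jet.jet(f, primals, series)
# Under jet's factorial-scaled convention, terms should then be
# [f'(x), f''(x), ...], so this is the 10th derivative of f at 1.1,
# computed in one pass with no nested retracing.
print(terms[-1])
```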