custom_jvp of an odeint function
Hi,
I'd like to use an odeint function to solve optimization problems. I found an example of a 4th-order Runge-Kutta (RK4) one at https://implicit-layers-tutorial.org/implicit_functions/#differentiation-of-functions-defined-by-ordinary-differential-equations-odes, with its custom derivatives. That interests me a lot, because I want to solve optimization problems with good accuracy. However, it produces an error that I can't solve. Here is my code (using the `odeint_rk4` function from that link):
```python
from jax import custom_jvp, jvp
import jax.numpy as np
from jax.lax import scan
from noloadj.optimization.optimProblem import OptimProblem, Spec
import math

def odeint_rk4(f, y0, t, *args):
    def step(state, t):
        y_prev, t_prev = state
        h = t - t_prev
        k1 = h * f(y_prev, t_prev, *args)
        k2 = h * f(y_prev + k1/2., t_prev + h/2., *args)
        k3 = h * f(y_prev + k2/2., t_prev + h/2., *args)
        # Last stage at the end of the step, t (= t_prev + h). The original `t + h` overshot
        # by one step; that is harmless only for ODEs that do not depend on t, like the one below.
        k4 = h * f(y_prev + k3, t, *args)
        y = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y, t), y
    _, ys = scan(step, (y0, t[0]), t[1:])
    return ys
```
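A quick way to check the time arguments of a fixed-step RK4 like this one is a non-autonomous ODE that RK4 solves exactly. The sketch below is a standalone copy of the solver (using the `jnp` alias and renaming the scanned time to `t_next`), and it evaluates the last stage at the end of the step, `f(y_prev + k3, t_next)`. For `y' = t**3` with `y(0) = 0`, `f` ignores `y`, so each step reduces to Simpson's rule, which integrates cubics exactly. `y(1)` should therefore come out as 0.25 up to float32 rounding. With `t + h` in that slot, the same check ends noticeably above 0.25.

```python
import jax.numpy as jnp
from jax.lax import scan


def odeint_rk4(f, y0, t, *args):
    """Fixed-step classic RK4 on the grid t; returns y at t[1:] (y0 itself is not included)."""
    def step(state, t_next):
        y_prev, t_prev = state
        h = t_next - t_prev
        k1 = h * f(y_prev, t_prev, *args)
        k2 = h * f(y_prev + k1 / 2., t_prev + h / 2., *args)
        k3 = h * f(y_prev + k2 / 2., t_prev + h / 2., *args)
        # Last stage at the end of the step: t_prev + h == t_next.
        k4 = h * f(y_prev + k3, t_next, *args)
        y = y_prev + (k1 + 2. * k2 + 2. * k3 + k4) / 6.
        return (y, t_next), y
    _, ys = scan(step, (y0, t[0]), t[1:])
    return ys


# y' = t**3, y(0) = 0  ->  exact solution y(t) = t**4 / 4.
ts = jnp.linspace(0., 1., 11)
ys = odeint_rk4(lambda y, t: t ** 3, jnp.zeros(()), ts)
print(ys[-1])  # close to 0.25
```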
```python
odeint_rk4 = custom_jvp(odeint_rk4, nondiff_argnums=(0,))

@odeint_rk4.defjvp
def odeint_rk4_jvp(f, primals, tangents):
    y0, t, *args = primals
    delta_y0, _, *delta_args = tangents
    nargs = len(args)

    def f_aug(aug_state, t, *args_and_delta_args):
        primal_state, tangent_state = aug_state
        args, delta_args = args_and_delta_args[:nargs], args_and_delta_args[nargs:]
        primal_dot, tangent_dot = jvp(f, (primal_state, t, *args), (tangent_state, 0., *delta_args))
        return np.stack([primal_dot, tangent_dot])

    aug_init_state = np.stack([y0, delta_y0])
    aug_states = odeint_rk4(f_aug, aug_init_state, t, *args, *delta_args)
    ys, ys_dot = aug_states[:, 0, :], aug_states[:, 1, :]
    return ys, ys_dot
```
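For reference, the `odeint_rk4_jvp` rule is a forward-sensitivity construction. Write the vector field as f(y, t, θ), where θ stands for the explicit trailing `*args`. A perturbation (δy₀, δθ) of the inputs then produces a tangent δy(t) that solves the linearized ODE below. The rule stacks [y, δy], and the `jvp(f, ...)` call inside `f_aug` supplies both right-hand sides at once. The stacked state is then integrated with the same solver on the same grid. This is a reading of the code above, not a statement from the tutorial.

```latex
\frac{d}{dt}\begin{bmatrix} y \\ \delta y \end{bmatrix}
= \begin{bmatrix} f(y, t, \theta) \\
  \partial_y f(y, t, \theta)\,\delta y + \partial_\theta f(y, t, \theta)\,\delta\theta \end{bmatrix},
\qquad
\begin{bmatrix} y \\ \delta y \end{bmatrix}(t_0) = \begin{bmatrix} y_0 \\ \delta y_0 \end{bmatrix}
```

The time tangent is fixed to `0.` and the tangent of the grid `t` is discarded (`_`), so the time grid itself is not differentiated. Anything `f` reads from a closure, rather than receiving through θ, has no term in this equation. That gap is what the `CustomJVPException` below complains about.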
```python
def lancer(m, R, v0, a, x0, y0, tf):
    g = 9.81
    # k = 0.5*rho_air*Cx*(sphere cross-section), for a drag force f = -k*v²
    k = 0.5*1.292*0.5*math.pi*R*R
    # drag reference: https://www.physagreg.fr/mecanique-12-chute-frottements.php
    vx0, vy0 = v0*np.cos(a), v0*np.sin(a)
    sol = odeint_rk4(lambda s, t: np.array([s[2], s[3],
                                            -k*s[2]*(s[2]*s[2]+s[3]*s[3])**0.5/m,
                                            -k*s[3]*(s[2]*s[2]+s[3]*s[3])**0.5/m - g]),
                     np.array([x0, y0, vx0, vy0]), np.linspace(0, tf, int(tf/1e-3)))
    s = np.transpose(sol)
    x, y = s[0], s[1]
    yf, xf = y[-1], x[-1]
    hauteur = np.max(y)  # hauteur = maximum height
    return locals().items()

spec = Spec(variables={'m': 1.0, 'R': 0.2}, bounds={'m': [0.5, 5.], 'R': [0.01, 1.0]},
            objectives=['hauteur'], eq_cstr={'yf': 0.0, 'xf': 22.0}, freeOutputs=['x', 'y'])
parameters = {'x0': 0., 'y0': 2., 'tf': 3, 'v0': 10, 'a': math.pi/4}
optim = OptimProblem(model=lancer, specifications=spec, parameters=parameters)
result = optim.run()
result.printResults()
```
It produces the error:
```
jax.interpreters.ad.CustomJVPException:
Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn't supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters.
Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.
```
Thanks in advance for helping me.
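The error message's suggestion can be applied to `lancer` directly. In the code above, the vector field is a lambda that closes over `m` and over `k`, which depends on `R`. Both are optimization variables, so the custom JVP rule never receives their tangents. The following is a minimal sketch under these assumptions: `m` and `k` are passed as explicit trailing arguments of `odeint_rk4`, the `noloadj` layer is left out, and `drag_rhs` and `final_position` are hypothetical names for a trimmed-down `lancer`. The sketch only exercises forward mode (`jax.jvp`), and it compares the result with float64 central finite differences.

```python
import math

import jax
import jax.numpy as jnp
from jax import custom_jvp, jvp
from jax.lax import scan

# float64 keeps the finite-difference comparison at the end meaningful.
jax.config.update("jax_enable_x64", True)


def odeint_rk4(f, y0, t, *args):
    # Fixed-step RK4, last stage evaluated at the end of the step.
    def step(state, t_next):
        y_prev, t_prev = state
        h = t_next - t_prev
        k1 = h * f(y_prev, t_prev, *args)
        k2 = h * f(y_prev + k1 / 2., t_prev + h / 2., *args)
        k3 = h * f(y_prev + k2 / 2., t_prev + h / 2., *args)
        k4 = h * f(y_prev + k3, t_next, *args)
        y = y_prev + (k1 + 2. * k2 + 2. * k3 + k4) / 6.
        return (y, t_next), y
    _, ys = scan(step, (y0, t[0]), t[1:])
    return ys


odeint_rk4 = custom_jvp(odeint_rk4, nondiff_argnums=(0,))


@odeint_rk4.defjvp
def odeint_rk4_jvp(f, primals, tangents):
    # Same logic as in the issue: integrate [y, dy] together with the same solver.
    y0, t, *args = primals
    delta_y0, _, *delta_args = tangents
    nargs = len(args)

    def f_aug(aug_state, t, *args_and_delta_args):
        primal_state, tangent_state = aug_state
        args, delta_args = args_and_delta_args[:nargs], args_and_delta_args[nargs:]
        primal_dot, tangent_dot = jvp(f, (primal_state, t, *args), (tangent_state, 0., *delta_args))
        return jnp.stack([primal_dot, tangent_dot])

    aug_states = odeint_rk4(f_aug, jnp.stack([y0, delta_y0]), t, *args, *delta_args)
    return aug_states[:, 0, :], aug_states[:, 1, :]


G = 9.81


def drag_rhs(s, t, m, k):
    # s = [x, y, vx, vy]; quadratic drag -k*|v|*v on a mass m, plus gravity.
    # m and k arrive as explicit arguments, so the JVP rule gets their tangents.
    speed = jnp.sqrt(s[2] ** 2 + s[3] ** 2)
    return jnp.array([s[2], s[3], -k * s[2] * speed / m, -k * s[3] * speed / m - G])


def final_position(m, R, v0=10., a=math.pi / 4, x0=0., y0=2., tf=3.):
    # Hypothetical trimmed-down `lancer` returning (xf, yf).
    k = 0.5 * 1.292 * 0.5 * math.pi * R * R
    s0 = jnp.array([x0, y0, v0 * jnp.cos(a), v0 * jnp.sin(a)])
    ts = jnp.linspace(0., tf, 301)  # coarser grid than the issue, to keep the demo fast
    sol = odeint_rk4(drag_rhs, s0, ts, m, k)  # nothing traced is closed over
    return sol[-1, 0], sol[-1, 1]


# Forward-mode derivatives of xf with respect to m and R at (m, R) = (1.0, 0.2).
_, (dxf_dm, _) = jax.jvp(lambda m: final_position(m, 0.2), (1.0,), (1.0,))
_, (dxf_dR, _) = jax.jvp(lambda R: final_position(1.0, R), (0.2,), (1.0,))

# Central finite differences on the same discretization, as a sanity check.
eps = 1e-6
fd_dm = (final_position(1.0 + eps, 0.2)[0] - final_position(1.0 - eps, 0.2)[0]) / (2 * eps)
fd_dR = (final_position(1.0, 0.2 + eps)[0] - final_position(1.0, 0.2 - eps)[0]) / (2 * eps)
print(float(dxf_dm), float(fd_dm))
print(float(dxf_dR), float(fd_dR))
```

In `lancer` itself, the same change means writing the vector field as `lambda s, t, m, k: ...` and calling `odeint_rk4(rhs, s0, ts, m, k)`. This sketch does not check whether reverse mode (`jax.grad`) also goes through this particular rule.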
Alternatively: use `jax.jvp` applied to the original `odeint_rk4` directly, and don't use a custom JVP at all. Numerically speaking, this direct method (called "discretise-then-optimise") is nearly always more accurate, and is what is actually desired; correspondingly, in nearly every case this alternative "continuous-time JVP" should be avoided.
… `diffrax.Tsit5` instead.)
Great!
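The discretise-then-optimise route needs no custom rule at all. This minimal sketch assumes the same projectile model as the issue. `odeint_rk4` is the plain `scan`-based solver without `custom_jvp`, and `final_x` is a hypothetical trimmed-down `lancer` that returns only `xf`. The vector field keeps closing over `m` and `k`, exactly as in the original code. JAX differentiates straight through the solver's steps, so both forward mode (`jax.jvp`) and reverse mode (`jax.grad`) apply. The sketch ends with a float64 finite-difference check.

```python
import math

import jax
import jax.numpy as jnp
from jax.lax import scan

# float64 keeps the finite-difference comparison at the end meaningful.
jax.config.update("jax_enable_x64", True)


def odeint_rk4(f, y0, t, *args):
    # Plain fixed-step RK4: no custom_jvp, so JAX differentiates through the scan itself.
    def step(state, t_next):
        y_prev, t_prev = state
        h = t_next - t_prev
        k1 = h * f(y_prev, t_prev, *args)
        k2 = h * f(y_prev + k1 / 2., t_prev + h / 2., *args)
        k3 = h * f(y_prev + k2 / 2., t_prev + h / 2., *args)
        k4 = h * f(y_prev + k3, t_next, *args)
        y = y_prev + (k1 + 2. * k2 + 2. * k3 + k4) / 6.
        return (y, t_next), y
    _, ys = scan(step, (y0, t[0]), t[1:])
    return ys


def final_x(m, R, v0=10., a=math.pi / 4, x0=0., y0=2., tf=3.):
    # Hypothetical trimmed-down `lancer` returning only xf.
    g = 9.81
    k = 0.5 * 1.292 * 0.5 * math.pi * R * R
    # The vector field closes over m and k, as in the issue; that is fine without custom_jvp.
    rhs = lambda s, t: jnp.array([
        s[2],
        s[3],
        -k * s[2] * jnp.sqrt(s[2] ** 2 + s[3] ** 2) / m,
        -k * s[3] * jnp.sqrt(s[2] ** 2 + s[3] ** 2) / m - g,
    ])
    s0 = jnp.array([x0, y0, v0 * jnp.cos(a), v0 * jnp.sin(a)])
    sol = odeint_rk4(rhs, s0, jnp.linspace(0., tf, 301))
    return sol[-1, 0]


# Forward mode, applied to the plain solver ...
_, dxf_dm = jax.jvp(lambda m: final_x(m, 0.2), (1.0,), (1.0,))
# ... and reverse mode, which also works because nothing here needs a custom rule.
grad_m, grad_R = jax.grad(final_x, argnums=(0, 1))(1.0, 0.2)

# Central finite differences on the same discretization, as a sanity check.
eps = 1e-6
fd_dm = (final_x(1.0 + eps, 0.2) - final_x(1.0 - eps, 0.2)) / (2 * eps)
fd_dR = (final_x(1.0, 0.2 + eps) - final_x(1.0, 0.2 - eps)) / (2 * eps)
print(float(dxf_dm), float(grad_m), float(fd_dm))
print(float(grad_R), float(fd_dR))
```

One trade-off is that reverse mode through `scan` keeps intermediate values for every step. With the issue's 1 ms grid over 3 s, that is about 3000 steps of a 4-component state, which is still small.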