Custom_jvp of an odeint function

Hi,

I’d like to use an odeint function to solve optimization problems. I found an example of a fourth-order Runge-Kutta (RK4) integrator, together with its custom derivative rule, at https://implicit-layers-tutorial.org/implicit_functions/#differentiation-of-functions-defined-by-ordinary-differential-equations-odes, and it looks very promising for solving optimization problems accurately. However, it produces an error that I can’t resolve. Here is my code (the odeint_rk4 function is taken from that link):

from jax import custom_jvp,jvp
import jax.numpy as np
from jax.lax import scan
from noloadj.optimization.optimProblem import OptimProblem,Spec
import math

def odeint_rk4(f, y0, t, *args):
  def step(state, t):
    y_prev, t_prev = state
    h = t - t_prev
    k1 = h * f(y_prev, t_prev, *args)
    k2 = h * f(y_prev + k1/2., t_prev + h/2., *args)
    k3 = h * f(y_prev + k2/2., t_prev + h/2., *args)
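    k4 = h * f(y_prev + k3, t, *args)  # t already equals t_prev + h; the tutorial's t + h overshoots (harmless here because the ODE is autonomous)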
    y = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return (y, t), y
  _, ys = scan(step, (y0, t[0]), t[1:])
  return ys

odeint_rk4 = custom_jvp(odeint_rk4, nondiff_argnums=(0,))

@odeint_rk4.defjvp
def odeint_rk4_jvp(f, primals, tangents):
  y0, t, *args = primals
  delta_y0, _, *delta_args = tangents
  nargs = len(args)

  def f_aug(aug_state, t, *args_and_delta_args):
    primal_state, tangent_state = aug_state
    args, delta_args = args_and_delta_args[:nargs], args_and_delta_args[nargs:]
    primal_dot, tangent_dot = jvp(f, (primal_state, t, *args), (tangent_state, 0., *delta_args))
    return np.stack([primal_dot, tangent_dot])

  aug_init_state = np.stack([y0, delta_y0])
  aug_states = odeint_rk4(f_aug, aug_init_state, t, *args, *delta_args)
  ys, ys_dot = aug_states[:, 0, :], aug_states[:, 1, :]
  return ys, ys_dot


def lancer(m,R,v0,a, x0, y0, tf):
    g = 9.81
    k=0.5*1.292*0.5*math.pi*R*R # k = 0.5*rho(air)*Cx*(sphere cross-section), for a drag force f = -k*v²
    # drag model: https://www.physagreg.fr/mecanique-12-chute-frottements.php
    vx0,vy0= v0*np.cos(a),v0*np.sin(a)
    sol = odeint_rk4(lambda s,t : np.array([s[2],s[3],-k*s[2]*(s[2]*s[2]+s[3]*s[3])**0.5/m, -k*s[3]*(s[2]*s[2]+s[3]*s[3])**0.5/m-g]),
                 np.array([x0, y0, vx0, vy0]),np.linspace(0,tf,int(tf/1e-3)))
    s=np.transpose(sol)
    x,y=s[0],s[1]
    yf,xf=y[-1],x[-1]
    hauteur=np.max(y)
    return locals().items()

spec=Spec(variables={'m':1.0,'R':0.2}, bounds={'m':[0.5,5.],'R':[0.01,1.0]},
          objectives=['hauteur'], eq_cstr={'yf':0.0,'xf':22.0},freeOutputs=['x','y'])

parameters={'x0':0.,'y0':2.,'tf':3,'v0':10,'a':math.pi/4}
optim=OptimProblem(model=lancer,specifications=spec,parameters=parameters)
result=optim.run()
result.printResults()

It produces the following error:

jax.interpreters.ad.CustomJVPException:
Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn’t supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters.
Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule. 

Thanks in advance for helping me.
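In equations, the model and the JVP rule read as follows, as far as I can tell from the code. With state s = (x, y, vx, vy) and k = 0.5·ρ·Cx·π·R² (ρ = 1.292 kg/m³, Cx = 0.5, as in lancer), the lambda encodes the first three lines below. odeint_rk4_jvp then integrates the augmented system on the last line, where θ stands for the extra *args passed to odeint_rk4 and δθ for their tangents (delta_args):

```latex
\begin{aligned}
\dot{x} &= v_x, \qquad \dot{y} = v_y,\\
\dot{v}_x &= -\frac{k}{m}\, v_x \sqrt{v_x^2 + v_y^2},\\
\dot{v}_y &= -\frac{k}{m}\, v_y \sqrt{v_x^2 + v_y^2} - g,\\[4pt]
\frac{d}{dt}\begin{pmatrix} s \\ \delta s \end{pmatrix}
&= \begin{pmatrix} f(s, t, \theta) \\ \partial_s f(s, t, \theta)\,\delta s + \partial_\theta f(s, t, \theta)\,\delta\theta \end{pmatrix},
\qquad \begin{pmatrix} s(t_0) \\ \delta s(t_0) \end{pmatrix} = \begin{pmatrix} s_0 \\ \delta s_0 \end{pmatrix}.
\end{aligned}
```

In lancer the lambda takes no extra arguments, so θ is empty and the ∂θ f · δθ term is never formed, even though k and m depend on the optimisation variables m and R through the closure. Rather than silently dropping that term, JAX raises the CustomJVPException quoted above.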

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 8 (4 by maintainers)

Top GitHub Comments

patrick-kidger commented, Apr 4, 2022

Alternatively:

  • Option 1: just use jax.jvp applied to the original odeint_rk4 directly; don’t use a custom JVP at all. Numerically speaking, this direct method (called “discretise-then-optimise”) is nearly always more accurate and is usually what is actually desired, since it differentiates the numerical solution you actually computed. Correspondingly, this alternative “continuous-time JVP” should be avoided in nearly every case.
  • Option 2: don’t write your own integrator; just use a library such as Diffrax instead. (Diffrax doesn’t actually provide RK4, because other solvers are usually better; I’d recommend diffrax.Tsit5 instead.)
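Option 1 might look like this for the model in the question. This is a sketch: lancer_outputs is a hypothetical name, its defaults are copied from the question's parameters, and noloadj's OptimProblem is left out so the snippet runs on its own. Without custom_jvp the lambda can keep closing over k and m, and both forward mode (jax.jacfwd) and reverse mode (jax.grad) differentiate straight through the scan.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan

G = 9.81  # gravitational acceleration, m/s^2


def odeint_rk4(f, y0, t, *args):
    # Plain fixed-step RK4 with no custom_jvp: JAX differentiates through scan.
    def step(state, t):
        y_prev, t_prev = state
        h = t - t_prev
        k1 = h * f(y_prev, t_prev, *args)
        k2 = h * f(y_prev + k1 / 2., t_prev + h / 2., *args)
        k3 = h * f(y_prev + k2 / 2., t_prev + h / 2., *args)
        k4 = h * f(y_prev + k3, t, *args)  # t == t_prev + h
        y = y_prev + (k1 + 2 * k2 + 2 * k3 + k4) / 6.
        return (y, t), y
    _, ys = scan(step, (y0, t[0]), t[1:])
    return ys


def lancer_outputs(m, R, v0=10.0, a=jnp.pi / 4, x0=0.0, y0=2.0, tf=3.0, n=3000):
    k = 0.5 * 1.292 * 0.5 * jnp.pi * R * R  # 0.5 * rho_air * Cx * pi * R^2

    def f(s, t):
        # Closing over k and m is fine here: no custom rule needs their tangents.
        speed = jnp.sqrt(s[2] ** 2 + s[3] ** 2)
        return jnp.array([s[2], s[3], -k * s[2] * speed / m, -k * s[3] * speed / m - G])

    s0 = jnp.array([x0, y0, v0 * jnp.cos(a), v0 * jnp.sin(a)])
    sol = odeint_rk4(f, s0, jnp.linspace(0.0, tf, n))
    x, y = sol[:, 0], sol[:, 1]
    return jnp.stack([jnp.max(y), x[-1], y[-1]])  # [hauteur, xf, yf]


# Forward-mode Jacobian of all three outputs, and a reverse-mode gradient of hauteur.
jac_m, jac_R = jax.jacfwd(lancer_outputs, argnums=(0, 1))(1.0, 0.2)
grad_m, grad_R = jax.grad(lambda m, R: lancer_outputs(m, R)[0], argnums=(0, 1))(1.0, 0.2)
```

The returned vector is [hauteur, xf, yf], i.e. the objective and the two equality-constrained outputs from the question's Spec, so an optimiser can take its derivatives from either jacfwd or grad.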
YouJiacheng commented, Apr 11, 2022

Great!
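If you do want to keep the custom JVP rule, the error message's own suggestion works: pass the closed-over values into odeint_rk4 as explicit arguments. A minimal sketch, assuming the same drag model as in the question: k and m are passed as *args, so their tangents reach delta_args instead of being hidden in the lambda's closure. The names dynamics and final_height are hypothetical, the defaults are copied from the question's parameters, and noloadj is left out. The sketch computes one jax.jvp per input direction; whether this rule also composes with jax.jacfwd or jax.grad is not checked here.

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp, jvp
from jax.lax import scan


def odeint_rk4(f, y0, t, *args):
    # Fixed-step classical RK4 on the grid t; returns the states at t[1:].
    def step(state, t):
        y_prev, t_prev = state
        h = t - t_prev
        k1 = h * f(y_prev, t_prev, *args)
        k2 = h * f(y_prev + k1 / 2., t_prev + h / 2., *args)
        k3 = h * f(y_prev + k2 / 2., t_prev + h / 2., *args)
        k4 = h * f(y_prev + k3, t, *args)  # t == t_prev + h
        y = y_prev + (k1 + 2 * k2 + 2 * k3 + k4) / 6.
        return (y, t), y
    _, ys = scan(step, (y0, t[0]), t[1:])
    return ys


odeint_rk4 = custom_jvp(odeint_rk4, nondiff_argnums=(0,))


@odeint_rk4.defjvp
def odeint_rk4_jvp(f, primals, tangents):
    # Same augmented-state rule as in the question: integrate state and tangent together.
    y0, t, *args = primals
    delta_y0, _, *delta_args = tangents
    nargs = len(args)

    def f_aug(aug_state, t, *args_and_delta_args):
        primal_state, tangent_state = aug_state
        args_, delta_args_ = args_and_delta_args[:nargs], args_and_delta_args[nargs:]
        # Tangent of t is zero: time itself is not differentiated.
        primal_dot, tangent_dot = jvp(
            f, (primal_state, t, *args_), (tangent_state, jnp.zeros_like(t), *delta_args_))
        return jnp.stack([primal_dot, tangent_dot])

    aug_states = odeint_rk4(f_aug, jnp.stack([y0, delta_y0]), t, *args, *delta_args)
    return aug_states[:, 0, :], aug_states[:, 1, :]


G = 9.81  # a plain Python constant is not traced, so closing over it is harmless


def dynamics(s, t, k, m):
    # s = [x, y, vx, vy]; k and m are explicit arguments, not closed-over values.
    speed = jnp.sqrt(s[2] ** 2 + s[3] ** 2)
    return jnp.array([s[2], s[3], -k * s[2] * speed / m, -k * s[3] * speed / m - G])


def final_height(m, R, v0=10.0, a=jnp.pi / 4, x0=0.0, y0=2.0, tf=3.0, n=3000):
    k = 0.5 * 1.292 * 0.5 * jnp.pi * R * R
    s0 = jnp.array([x0, y0, v0 * jnp.cos(a), v0 * jnp.sin(a)])
    ts = jnp.linspace(0.0, tf, n)
    sol = odeint_rk4(dynamics, s0, ts, k, m)  # k and m now reach the JVP rule's delta_args
    return sol[-1, 1]  # y at t = tf


# One forward-mode directional derivative per optimisation variable.
_, dyf_dm = jax.jvp(final_height, (1.0, 0.2), (1.0, 0.0))
_, dyf_dR = jax.jvp(final_height, (1.0, 0.2), (0.0, 1.0))
```

The k4 stage here uses t, which already equals t_prev + h, rather than the tutorial's t + h; that makes no difference for this autonomous ODE but would matter if f depended on t.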
