Slow `jit` compilation time compared to `jax.experimental.ode.odeint`
Hi @patrick-kidger,

Big fan of diffrax! I've been playing around with some of the functionality in this repository and comparing it with the ODE solver in JAX. The one pain point I noticed is that `jit` compilation seems relatively slow, particularly when I `jit` the `grad` of a simple loss function containing `diffeqsolve`. I was wondering if this is an error on my part (perhaps I botched the `diffrax` implementation) or if there is yet to be some optimization. The demonstration is below:
```python
from jax.config import config
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)
config.parse_flags_with_absl()

import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from functools import partial
import haiku as hk


def exact_kinematic_aug_diff_f(t, y, args_tuple):
    """Augmented vector field in diffrax's f(t, y, args) convention."""
    _y, _, _ = y
    _params, _key, diff_f = args_tuple
    aug_diff_fn = lambda __y: diff_f(t, __y, (_params,))
    _f, s, t = aug_diff_fn(_y)
    r = jnp.sum(t)
    return _f, r, 0.


def exact_kinematic_odeint_diff_f(y, t, params, canonical_diff_fn):
    """Augmented vector field in odeint's f(y, t, *args) convention."""
    run_y = y[0]
    _f, s, t = canonical_diff_fn(t, run_y, (params,))
    return _f, jnp.sum(s), 0.


class TestMLP(hk.Module):
    def __init__(self, num_particles, name=None):
        super().__init__(name=name)  # forward the module name to Haiku
        self._mlp = hk.nets.MLP([8, 8, 8, 8, num_particles * 12])
        self._num_particles = num_particles

    def __call__(self, t, y):
        in_y = (y + t).flatten()
        outter = self._mlp(in_y).reshape((4, self._num_particles, 3))
        return outter[:2], outter[2], outter[3]


def test(num_particles):
    import functools
    from jax.experimental.ode import odeint
    import diffrax

    # generate positions/velocities
    small_positions = jax.random.normal(jax.random.PRNGKey(261), shape=(num_particles, 3))
    small_velocities = jax.random.normal(jax.random.PRNGKey(235), shape=(num_particles, 3))
    small_positions_and_velocities = jnp.vstack([small_positions[jnp.newaxis, ...], small_velocities[jnp.newaxis, ...]])

    # make module kwargs
    VectorMLP_kwargs = {'num_particles': num_particles}

    # make module function
    def _diff_f_wrapper(t, y):
        diff_f = TestMLP(**VectorMLP_kwargs)
        return diff_f(t, y)

    diff_f_init, diff_f_apply = hk.without_apply_rng(hk.transform(_diff_f_wrapper))
    init_params = diff_f_init(jax.random.PRNGKey(36), 0., small_positions_and_velocities)
    canonicalized_diff_f_fn = lambda _t, _y, _args_tuple: diff_f_apply(_args_tuple[0], _t, _y)

    # make the augmented functions
    odeint_aug_diff_func = functools.partial(exact_kinematic_odeint_diff_f, canonical_diff_fn=canonicalized_diff_f_fn)
    diffeqsolve_aug_diff_func = exact_kinematic_aug_diff_f

    # odeint solver
    def odeint_solver(_parameters, _init_y, _key):
        aug_init_y = (_init_y, 0., 0.)
        outs = odeint(odeint_aug_diff_func, aug_init_y, jnp.array([0., 1.]), _parameters, rtol=1.4e-8, atol=1.4e-8)
        final_outs = (outs[0][-1], outs[1][-1], outs[2][-1])
        return final_outs

    # diffrax solver
    def diffrax_ode_solver(_parameters, _init_y, _key):
        term = diffrax.ODETerm(diffeqsolve_aug_diff_func)
        stepsize_controller = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
        solver = diffrax.Dopri5()
        aug_init_y = (_init_y, 0., 0.)
        sol = diffrax.diffeqsolve(term,
                                  solver,
                                  t0=0.,
                                  t1=1.,
                                  dt0=1e-1,
                                  y0=aug_init_y,
                                  stepsize_controller=stepsize_controller,
                                  args=(_parameters, _key, canonicalized_diff_f_fn))
        # default SaveAt(t1=True): each output has a leading axis of length 1
        return sol.ys[0][0], sol.ys[1][0], sol.ys[2][0]

    @jax.jit
    def odeint_loss_fn(_params, _init_y, _key):
        ode_solution = odeint_solver(_params, _init_y, _key)
        return jnp.sum(ode_solution[1]**2)

    @jax.jit
    def diffrax_loss_fn(_params, _init_y, _key):
        ode_solution = diffrax_ode_solver(_params, _init_y, _key)
        return jnp.sum((ode_solution[1])**2)

    # test
    import time

    # odeint compilation time (first call: trace + compile + run)
    start_time = time.time()
    _ = jax.grad(odeint_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
    end_time = time.time()
    print(f"odeint comp. time: {end_time - start_time}")

    # diffrax compilation time (first call: trace + compile + run)
    start_time = time.time()
    _ = jax.grad(diffrax_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
    end_time = time.time()
    print(f"diffrax comp. time: {end_time - start_time}")
```
Running `test(8)` gives me the following compilation times on CPU:

```
odeint comp. time: 2.5580570697784424
diffrax comp. time: 23.965799570083618
```
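The timings above cover tracing, compilation and the first execution together, and they apply `jax.grad` to an already-jitted loss. To isolate compile time, JAX's ahead-of-time API can lower and compile a function without running it. A minimal sketch, in which `loss` is a hypothetical toy stand-in for `odeint_loss_fn`/`diffrax_loss_fn` and the whole gradient function is jitted (`jax.jit(jax.grad(loss))`):

```python
import time

import jax
import jax.numpy as jnp


def loss(params, y0):
    # Hypothetical stand-in for odeint_loss_fn / diffrax_loss_fn.
    return jnp.sum(jnp.tanh(params * y0) ** 2)


params = jnp.ones((8, 3))
y0 = jnp.linspace(0.0, 1.0, 24).reshape(8, 3)

# jit the gradient function as a whole.
grad_fn = jax.jit(jax.grad(loss))

# Ahead-of-time: trace, lower to XLA, and compile without executing.
start = time.perf_counter()
compiled = grad_fn.lower(params, y0).compile()
compile_time = time.perf_counter() - start

# Run the already-compiled function; block because JAX dispatches asynchronously.
start = time.perf_counter()
grads = jax.block_until_ready(compiled(params, y0))
run_time = time.perf_counter() - start

print(f"compile: {compile_time:.3f}s, run: {run_time:.6f}s")
```

The same pattern should carry over to the benchmark, e.g. `jax.jit(jax.grad(diffrax_loss_fn)).lower(init_params, small_positions_and_velocities, jax.random.PRNGKey(34)).compile()`, which separates the compile cost from the first solve.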
I noticed that if I use `diffrax.BacksolveAdjoint`, compilation time goes down to ~8 seconds, but I'm keen to avoid that method based on your docs. Also, the compilation time in `diffrax` seems to depend heavily on the number of hidden layers in `TestMLP`, perhaps suggesting that `diffrax` compiles `for` loops non-optimally?

Thanks!
Top GitHub Comments
This issue should now be resolved in #124. TL;DR: we should now be as fast as `jax.experimental.ode` at compilation.

Thanks to everyone who's been involved in this issue. If you can, I'd invite you to try out this branch some time in the next few days. It'd be great to: `scan_stages`, as I haven't yet decided which way to set the default value.

Hi Patrick, thanks! I will try replacing the `for` loop, and if it works, I'll add a PR.