
Slow `jit` compilation time compared to `jax.experimental.ode.odeint`


Hi @patrick-kidger, big fan of diffrax!

I’ve been playing around with some of the functionality in this repository and comparing it with the ODE solver in JAX. The one pain point I noticed is a relatively slow `jit` compilation time, particularly when I try to `jit` the `grad` of a simple loss function containing `diffeqsolve`. I was wondering whether this is an error on my part (perhaps I botched the diffrax implementation) or whether there is optimization yet to be done. A demonstration is below:

from jax.config import config
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", True)
config.parse_flags_with_absl()
import jax
import jax.numpy as jnp
import haiku as hk

def exact_kinematic_aug_diff_f(t, y, args_tuple):
    """Augmented vector field, in the (t, y, args) signature expected by diffrax."""
    _y, _, _ = y
    _params, _key, diff_f = args_tuple
    aug_diff_fn = lambda __y: diff_f(t, __y, (_params,))
    _f, s, _ = aug_diff_fn(_y)  # avoid shadowing the time variable `t`
    return _f, jnp.sum(s), 0.

def exact_kinematic_odeint_diff_f(y, t, params, canonical_diff_fn):
    """The same augmented vector field, in the (y, t, *args) signature expected by odeint."""
    run_y = y[0]
    _f, s, _ = canonical_diff_fn(t, run_y, (params,))
    return _f, jnp.sum(s), 0.

class TestMLP(hk.Module):
    def __init__(self, num_particles, name=None):
        super().__init__(name=name)
        self._mlp = hk.nets.MLP([8, 8, 8, 8, num_particles * 12])
        self._num_particles = num_particles

    def __call__(self, t, y):
        in_y = (y + t).flatten()
        out = self._mlp(in_y).reshape((4, self._num_particles, 3))
        return out[:2], out[2], out[3]

def test(num_particles):
    import functools
    from jax.experimental.ode import odeint
    import diffrax
    
    # generate positions/velocities
    small_positions = jax.random.normal(jax.random.PRNGKey(261), shape=(num_particles,3))
    small_velocities = jax.random.normal(jax.random.PRNGKey(235), shape=(num_particles,3))
    small_positions_and_velocities = jnp.vstack([small_positions[jnp.newaxis, ...], small_velocities[jnp.newaxis, ...]])
    
    # make module kwargs
    VectorMLP_kwargs = {'num_particles': num_particles}
    
    # make module function
    def _diff_f_wrapper(t, y):
        diff_f = TestMLP(**VectorMLP_kwargs)
        return diff_f(t, y)
    
    diff_f_init, diff_f_apply = hk.without_apply_rng(hk.transform(_diff_f_wrapper))
    init_params = diff_f_init(jax.random.PRNGKey(36), 0., small_positions_and_velocities)
    canonicalized_diff_f_fn = lambda _t, _y, _args_tuple: diff_f_apply(_args_tuple[0], _t, _y)
    
    # make the augmented functions
    odeint_aug_diff_func = functools.partial(exact_kinematic_odeint_diff_f, canonical_diff_fn=canonicalized_diff_f_fn)
    diffeqsolve_aug_diff_func = exact_kinematic_aug_diff_f
    
    # odeint solver
    def odeint_solver(_parameters, _init_y, _key):
        aug_init_y = (_init_y, 0., 0.)
        outs = odeint(odeint_aug_diff_func, aug_init_y, jnp.array([0., 1.]), _parameters, rtol=1.4e-8, atol=1.4e-8)
        final_outs = (outs[0][-1], outs[1][-1], outs[2][-1])
        return final_outs
    
    def diffrax_ode_solver(_parameters, _init_y, _key):
        term = diffrax.ODETerm(diffeqsolve_aug_diff_func)
        stepsize_controller = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
        solver = diffrax.Dopri5()
        aug_init_y = (_init_y, 0., 0.)
        sol = diffrax.diffeqsolve(term, 
                                  solver, 
                                  t0=0., 
                                  t1=1., 
                                  dt0=1e-1, 
                                  y0=aug_init_y, 
                                  stepsize_controller=stepsize_controller, 
                                  args=(_parameters, _key, canonicalized_diff_f_fn))
        return sol.ys[0][0], sol.ys[1][0], sol.ys[2][0]
    
    @jax.jit
    def odeint_loss_fn(_params, _init_y, _key):
        ode_solution = odeint_solver(_params, _init_y, _key)
        return jnp.sum(ode_solution[1]**2)
    
    @jax.jit
    def diffrax_loss_fn(_params, _init_y, _key):
        ode_solution = diffrax_ode_solver(_params, _init_y, _key)
        return jnp.sum((ode_solution[1])**2)
    
    # test
    import time
    
    # odeint: time the first call (jit compilation + one run)
    start_time = time.time()
    _ = jax.grad(odeint_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
    end_time = time.time()
    print(f"odeint comp. time: {end_time - start_time}")

    # diffrax: time the first call (jit compilation + one run)
    start_time = time.time()
    _ = jax.grad(diffrax_loss_fn)(init_params, small_positions_and_velocities, jax.random.PRNGKey(34))
    end_time = time.time()
    print(f"diffrax comp. time: {end_time - start_time}")

Running `test(8)` gives me the following compilation times on CPU:

odeint comp. time: 2.5580570697784424
diffrax comp. time: 23.965799570083618
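
For what it’s worth, compilation can also be timed separately from execution using JAX’s ahead-of-time lowering API. A minimal sketch, assuming a JAX version new enough to expose `.lower()`/`.compile()` on jitted functions, and reusing the names defined inside `test` above:

import time

grad_fn = jax.jit(jax.grad(diffrax_loss_fn))
args = (init_params, small_positions_and_velocities, jax.random.PRNGKey(34))

start_time = time.time()
compiled = grad_fn.lower(*args).compile()  # trace + XLA-compile, without executing
print(f"diffrax compile-only time: {time.time() - start_time}")

start_time = time.time()
grads = compiled(*args)
# wait out JAX's async dispatch before stopping the clock
jax.tree_util.tree_map(lambda x: x.block_until_ready(), grads)
print(f"diffrax run-only time: {time.time() - start_time}")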

I noticed that if I use `diffrax.BacksolveAdjoint`, compilation time drops to ~8 seconds, but I’m keen to avoid that method based on your docs (the exact change is sketched below). Also, the compilation time in diffrax seems heavily dependent on the number of hidden layers in `TestMLP`, perhaps suggesting non-optimal compilation of Python for loops in diffrax? Thanks!
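
For reference, the `BacksolveAdjoint` variant is the same `diffeqsolve` call inside `diffrax_ode_solver` with only the adjoint argument changed; a minimal sketch:

        sol = diffrax.diffeqsolve(term,
                                  solver,
                                  t0=0.,
                                  t1=1.,
                                  dt0=1e-1,
                                  y0=aug_init_y,
                                  stepsize_controller=stepsize_controller,
                                  # continuous-adjoint ("backsolve") backprop: compiles faster here,
                                  # but the diffrax docs caution about its accuracy
                                  adjoint=diffrax.BacksolveAdjoint(),
                                  args=(_parameters, _key, canonicalized_diff_f_fn))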

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 11 (5 by maintainers)

Top GitHub Comments

1 reaction

patrick-kidger commented, Jul 5, 2022

This issue should now be resolved in #124. TL;DR: we should now be as fast as `jax.experimental.ode` at compilation.

Thanks to everyone who’s been involved in this issue. If you can, I’d invite you to try out this branch some time in the next few days. It’d be great to:

  1. Verify that this resolves things for you (and not just on my benchmarks);
  2. Get some more real-world performance numbers for compile time/runtime, with and without `scan_stages`, as I haven’t yet decided which way to set the default value (see the sketch after this list);
  3. Double-check that everything seems to be behaving as you expect (i.e. that you are in fact obtaining the numerical solution to a differential equation), as this is a pretty substantial rewrite of the solver internals.
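
For point 2, a minimal sketch of toggling `scan_stages`, assuming it is exposed as a constructor argument on the Runge-Kutta solvers on that branch (check #124 for the exact spelling):

import diffrax

# hypothetical placement of the flag, per the assumption above
solver = diffrax.Dopri5(scan_stages=True)    # lax.scan over RK stages: smaller program, faster compile
# solver = diffrax.Dopri5(scan_stages=False)  # unrolled stages: larger program, possibly faster runtime

Everything else would be passed to `diffrax.diffeqsolve` exactly as before.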
0 reactions

druidowm commented, Jun 3, 2022

Hi Patrick, thanks! I will try replacing the for loop, and if it works, I’ll add a PR.
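
For context, the general pattern being discussed is replacing a Python-level loop (which unrolls into a program that grows with the iteration count at trace time) with `jax.lax.scan`; a generic sketch, not the actual diffrax internals:

import jax
import jax.numpy as jnp

def step(carry, _):
    # one iteration of work; traced once under lax.scan,
    # rather than once per iteration under a Python loop
    return carry * 2.0 + 1.0, None

# Python loop: traced program size grows linearly with the iteration count
x = jnp.array(1.0)
for _ in range(100):
    x = x * 2.0 + 1.0

# lax.scan: constant-size traced program regardless of iteration count
y, _ = jax.lax.scan(step, jnp.array(1.0), None, length=100)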
