
forward-mode autodiff for odeint

See original GitHub issue

Hi Jax team,

We want to calculate Hessians of a likelihood function involving an ODE integration so that we can do variational inference. We are running into an issue with custom_vjp that we don’t understand how to fix; our impression is that forward-mode differentiation is not implemented for odeint. Our package is called ticktack, which is distributed on PyPI, and the dataset miyake12.csv is hosted on GitHub here. Do you have any advice? Can we implement this easily, or are there plans to do this for odeint?

A minimal example:

import numpy as np
from jax import jit, hessian

import ticktack
from ticktack import fitting

cbm = ticktack.load_presaved_model('Guttler14', production_rate_units='atoms/cm^2/s')
cf = fitting.CarbonFitter(cbm)
default_params = [775., 1./12, np.pi/2., 81./12]
cf.load_data('miyake12.csv')

g = jit(hessian(cf.log_prob))
g(default_params)

We get the following output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_31044/2345364585.py in <module>
----> 1 g(default_params)

    [... skipping hidden 49 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in log_prob(self, params)
    107         # call log_like and log_prior, for later MCMC
    108         lp = self.log_prior(params)
--> 109         pos = self.log_like(params)
    110         return lp + pos
    111 

    [... skipping hidden 25 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in log_like(self, params)
     92     def log_like(self, params):
     93         # calls dc14 and compare to data, (can be gp or gaussian loglikelihood)
---> 94         d_14_c = self.dc14(params)
     95 
     96         chi2 = jnp.sum(((self.d14c_data[:-1] - d_14_c)/self.d14c_data_error[:-1])**2)

    [... skipping hidden 25 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in dc14(self, params)
     78     def dc14(self, params):
     79     # calls CBM on production_rate of params
---> 80         burn_in = self.run(self.burn_in_time, params, self.steady_state_y0)
     81         d_14_c = self.run_D_14_C_values(self.time_data, self.time_oversample, params, burn_in[-1, :])
     82         return d_14_c - 22.72

    [... skipping hidden 25 frame]

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in run(self, time_values, params, y0)
     65     @partial(jit, static_argnums=(0,))
     66     def run(self, time_values, params, y0):
---> 67         burn_in, _ = self.cbm.run(time_values, production=self.miyake_event, args=params, y0=y0)
     68         return burn_in
     69 

/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/ticktack.py in run(self, time_values, production, y0, args, target_C_14, steady_state_production)
    323 
    324         if USE_JAX:
--> 325             states = odeint(derivative, y_initial, time_values)
    326         else:
    327             states = odeint(derivative, y_initial, time_values)

~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    171 
    172   converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173   return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
    174 
    175 @partial(jax.jit, static_argnums=(0, 1, 2, 3))

    [... skipping hidden 25 frame]

~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
    177   y0, unravel = ravel_pytree(y0)
    178   func = ravel_first_arg(func, unravel)
--> 179   out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    180   return jax.vmap(unravel)(out)
    181 

    [... skipping hidden 4 frame]

~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
    216 
    217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218   ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
    219   return ys, (ys, ts, args)
    220 

    [... skipping hidden 5 frame]

~/.local/lib/python3.8/site-packages/jax/interpreters/ad.py in _raise_custom_vjp_error_on_jvp(*_, **__)
    676 
    677 def _raise_custom_vjp_error_on_jvp(*_, **__):
--> 678   raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
    679                   "function.")
    680 custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
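The error is raised because odeint is defined with a custom_vjp rule (reverse mode only), while jax.hessian is jacfwd-of-jacrev, so the outer jacfwd asks the custom_vjp function for a JVP it does not have. The same error can be reproduced without ticktack; this is a minimal sketch using only jax.experimental.ode (the toy ODE and parameter value below are illustrative, not from the original issue):

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def final_state(theta):
    # toy ODE dy/dt = -theta * y, solved from t=0 to t=1
    ys = odeint(lambda y, t: -theta * y, jnp.array([1.0]), jnp.array([0.0, 1.0]))
    return ys[-1, 0]

# reverse mode works, since odeint has a custom_vjp rule
print(jax.grad(final_state)(2.0))  # approx -exp(-2)

# forward mode (and hence jax.hessian, which is jacfwd-of-jacrev) fails
try:
    jax.jacfwd(final_state)(2.0)
except TypeError as e:
    print(e)  # can't apply forward-mode autodiff (jvp) to a custom_vjp function.
```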

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
benjaminpope commented, Jul 29, 2021

Cool, awesome. I think for now we’ll give it a go with jacrev, but we’d love to know if odeint gets an upgrade. I’m really excited by the stuff you’re doing.
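The jacrev workaround mentioned here can be sketched on a toy problem (the ODE and loss below are illustrative, not part of ticktack): composing jacrev with itself keeps both differentiation passes in reverse mode, so the custom_vjp rule of odeint is never asked for a JVP.

```python
import jax.numpy as jnp
from jax import jacrev, jit
from jax.experimental.ode import odeint

def loss(theta):
    # toy ODE dy/dt = -theta * y; sum-of-squares "likelihood" of the trajectory
    ys = odeint(lambda y, t, th: -th * y, jnp.array([1.0]),
                jnp.linspace(0.0, 1.0, 10), theta)
    return jnp.sum(ys ** 2)

# jax.hessian is jacfwd(jacrev(...)) and hits the custom_vjp error;
# reverse-over-reverse avoids forward mode entirely
hess = jit(jacrev(jacrev(loss)))
print(hess(2.0))
```

Note that reverse-over-reverse is typically less efficient than jacfwd-of-jacrev for Hessians, but it is the composition that works with odeint's custom_vjp today.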

0 reactions
dumanah commented, May 17, 2022

For those who are interested in this problem, I would suggest having a look at diffrax. Not only can forward-mode differentiation be applied to its ODE solver, it also offers many different ODE solvers, such as Euler, Heun, Dopri5, and other Runge–Kutta methods.

Read more comments on GitHub >
