forward-mode autodiff for odeint
Hi JAX team,
We want to calculate Hessians of a likelihood function that involves an ODE integration, so that we can do variational inference. We are running into an issue with custom_vjp that we don't understand how to fix. Our impression is that forward-mode differentiation is not implemented for odeint. Our package is called ticktack and is distributed on PyPI. The dataset miyake12.csv is hosted on GitHub here. Do you have any advice? Can we implement this easily, or are there plans to do this for odeint?
A minimal example:
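import numpy as np  # assumed: the original snippet omitted its imports
from jax import hessian, jit

import ticktack
from ticktack import fitting

# Load the presaved carbon box model and wrap it in a fitter
cbm = ticktack.load_presaved_model('Guttler14', production_rate_units='atoms/cm^2/s')
cf = fitting.CarbonFitter(cbm)
default_params = [775., 1./12, np.pi/2., 81./12]
cf.load_data('miyake12.csv')

# hessian differentiates forward-over-reverse, which is what fails on odeint
g = jit(hessian(cf.log_prob))
g(default_params)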
We get the following output:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_31044/2345364585.py in <module>
----> 1 g(default_params)
[... skipping hidden 49 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in log_prob(self, params)
107 # call log_like and log_prior, for later MCMC
108 lp = self.log_prior(params)
--> 109 pos = self.log_like(params)
110 return lp + pos
111
[... skipping hidden 25 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in log_like(self, params)
92 def log_like(self, params):
93 # calls dc14 and compare to data, (can be gp or gaussian loglikelihood)
---> 94 d_14_c = self.dc14(params)
95
96 chi2 = jnp.sum(((self.d14c_data[:-1] - d_14_c)/self.d14c_data_error[:-1])**2)
[... skipping hidden 25 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in dc14(self, params)
78 def dc14(self, params):
79 # calls CBM on production_rate of params
---> 80 burn_in = self.run(self.burn_in_time, params, self.steady_state_y0)
81 d_14_c = self.run_D_14_C_values(self.time_data, self.time_oversample, params, burn_in[-1, :])
82 return d_14_c - 22.72
[... skipping hidden 25 frame]
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/fitting.py in run(self, time_values, params, y0)
65 @partial(jit, static_argnums=(0,))
66 def run(self, time_values, params, y0):
---> 67 burn_in, _ = self.cbm.run(time_values, production=self.miyake_event, args=params, y0=y0)
68 return burn_in
69
/usr/local/lib/python3.8/dist-packages/ticktack-0.1.2.0-py3.8.egg/ticktack/ticktack.py in run(self, time_values, production, y0, args, target_C_14, steady_state_production)
323
324 if USE_JAX:
--> 325 states = odeint(derivative, y_initial, time_values)
326 else:
327 states = odeint(derivative, y_initial, time_values)
~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
171
172 converted, consts = custom_derivatives.closure_convert(func, y0, t[0], *args)
--> 173 return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
174
175 @partial(jax.jit, static_argnums=(0, 1, 2, 3))
[... skipping hidden 25 frame]
~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in _odeint_wrapper(func, rtol, atol, mxstep, y0, ts, *args)
177 y0, unravel = ravel_pytree(y0)
178 func = ravel_first_arg(func, unravel)
--> 179 out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
180 return jax.vmap(unravel)(out)
181
[... skipping hidden 4 frame]
~/.local/lib/python3.8/site-packages/jax/experimental/ode.py in _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args)
216
217 def _odeint_fwd(func, rtol, atol, mxstep, y0, ts, *args):
--> 218 ys = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
219 return ys, (ys, ts, args)
220
[... skipping hidden 5 frame]
~/.local/lib/python3.8/site-packages/jax/interpreters/ad.py in _raise_custom_vjp_error_on_jvp(*_, **__)
676
677 def _raise_custom_vjp_error_on_jvp(*_, **__):
--> 678 raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
679 "function.")
680 custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
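The traceback can be reproduced without ticktack or odeint. The sketch below is an illustration only: f, f_fwd and f_bwd are hypothetical stand-ins, not code from ticktack or JAX. It defines a small custom_vjp function whose forward rule calls the function itself, as _odeint_fwd calls _odeint in the traceback. jax.hessian is jacfwd applied to jacrev, so it needs a forward-mode (jvp) rule that a custom_vjp function does not provide. Composing jacrev with itself stays in reverse mode and should avoid the problem.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    # Hypothetical stand-in for a function that only defines a reverse-mode rule
    return jnp.sum(jnp.sin(x) ** 2)


def f_fwd(x):
    # Forward pass of reverse mode: calls f itself and saves x as the residual
    return f(x), x


def f_bwd(x, g):
    # Backward pass: d/dx sum(sin(x)^2) = 2 sin(x) cos(x) = sin(2x)
    return (g * jnp.sin(2.0 * x),)


f.defvjp(f_fwd, f_bwd)

x = jnp.array([0.1, 0.7, 1.3])

# jax.hessian is forward-over-reverse (jacfwd of jacrev); on the JAX version
# in the traceback this raises the same TypeError as in the issue.
try:
    jax.hessian(f)(x)
    forward_over_reverse_ok = True
except TypeError as err:
    forward_over_reverse_ok = False
    print("hessian failed:", err)

# Reverse-over-reverse only ever uses the custom VJP rule
hess_rev = jax.jacrev(jax.jacrev(f))(x)
print(hess_rev)
```

Reverse-over-reverse is usually slower than forward-over-reverse for a scalar-valued function, so this is a workaround and not a replacement for a proper jvp rule.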
Issue Analytics
- State:
- Created 2 years ago
- Comments:5 (2 by maintainers)
Top GitHub Comments
Cool, awesome. I think for now we'll give it a go with jacrev, but we'd love to know if odeint gets an upgrade. I'm really excited by the stuff you're doing.
For those who are interested in this problem, I would suggest having a look at diffrax. Not only can forward-mode differentiation be applied to its ODE solvers, it also offers many different types of ODE solver, such as Euler, Heun, Dopri5, and other Runge-Kutta methods.
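A minimal sketch of the jacrev route mentioned in the comments, applied directly to jax.experimental.ode.odeint. The decay ODE, the function names and the tolerances are assumptions chosen for illustration and are unrelated to the ticktack model. The exact solution y(1) = 2 exp(-k) lets the derivatives be checked by hand.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def decay(y, t, k):
    # Hypothetical toy dynamics: dy/dt = -k * y
    return -k * y


def final_state(k):
    # Integrate from t=0 to t=1 starting at y0=2 and return y(1)
    ts = jnp.array([0.0, 1.0])
    ys = odeint(decay, jnp.array(2.0), ts, k, rtol=1e-6, atol=1e-6)
    return ys[-1]


k = jnp.array(0.5)

# First derivative through odeint's custom VJP (adjoint) rule
grad_k = jax.jacrev(final_state)(k)

# Second derivative by reverse-over-reverse, in place of jax.hessian
hess_k = jax.jacrev(jax.jacrev(final_state))(k)

# Analytic references: dy/dk = -2 exp(-k), d2y/dk2 = 2 exp(-k)
print(grad_k, -2.0 * jnp.exp(-k))
print(hess_k, 2.0 * jnp.exp(-k))
```

For a likelihood with several parameters the same composition returns the full Hessian matrix, at the cost of differentiating through the adjoint solve a second time.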