More flexible ODE integration
The current implementation of Runge-Kutta with adjoint reverse-mode gradients is great, but there are a few things I still find myself missing. I’d love to help contribute them, or just to see them land in JAX one way or another.
- Auxiliary solver output. Number of function evaluations, location of time steps, diagnostic information, etc. Such output is useful from both the forward and the adjoint solves. One metric I’m especially interested in at the moment is the difference between the initial `y0` and the “hopefully close” `y(t_0)` obtained by backtracking through the dynamics in the adjoint ODE.
- Differentiating through the solver, aka “lame” gradients. I know this is often inferior to solving the reverse-time adjoint ODE, but for the sake of comparison it’s an essential baseline. I’m assuming this shouldn’t be too hard?
- Alternate solver choices. The scipy library does a good job of exposing multiple solver options to the user. I’m not sure its API is the cleanest approach, but being able to choose between types of solvers would be great; being locked in to RK can be annoying. I envision a design with a set of raw solvers defined behind a unified API, which can then be bundled into a usable `odeint` function with a VJP rule. Being able to select different solvers for the forward and adjoint passes would also be useful. The ideal solution would make arbitrary solver combinations a cinch, e.g. run RK on the forward pass but Euler integration for the adjoint.
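A minimal sketch of the auxiliary-output idea, assuming a fixed-step RK4 solver rather than JAX's adaptive one. The function `rk4_with_aux` and its dict keys `num_fevals` and `step_times` are hypothetical names chosen for illustration, not an existing JAX API. The backtracking check integrates the forward result from `t1` back to `t0` (the same backward re-solve of `y` that an adjoint pass performs) and measures how far it lands from `y0`.

```python
import jax
import jax.numpy as jnp


def rk4_with_aux(func, y0, t0, t1, num_steps):
    """Hypothetical fixed-step RK4 returning y(t1) plus a dict of diagnostics."""
    dt = (t1 - t0) / num_steps

    def body(y, i):
        t = t0 + i * dt
        k1 = func(y, t)
        k2 = func(y + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = func(y + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = func(y + dt * k3, t + dt)
        # Emit t as a scan output so every step location is recorded.
        return y + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4), t

    y1, step_times = jax.lax.scan(body, y0, jnp.arange(num_steps))
    aux = {
        "num_fevals": 4 * num_steps,  # RK4 calls func four times per step
        "step_times": step_times,     # start time of each step
    }
    return y1, aux


def van_der_pol(y, t, mu=1.0):
    # Example nonlinear dynamics; any func(y, t) would do.
    x, v = y[0], y[1]
    return jnp.stack([v, mu * (1.0 - x**2) * v - x])


y0 = jnp.array([2.0, 0.0])
y1, fwd_aux = rk4_with_aux(van_der_pol, y0, 0.0, 5.0, 50)
# Backtrack from t1 to t0 (negative dt), as the adjoint pass does for y.
y0_back, bwd_aux = rk4_with_aux(van_der_pol, y1, 5.0, 0.0, 50)
reconstruction_error = jnp.linalg.norm(y0_back - y0)
print(fwd_aux["num_fevals"], reconstruction_error)
```

With an adaptive solver the counters would instead have to live in the loop carry (e.g. of a `jax.lax.while_loop`), since the number of steps is not known ahead of time.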
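A minimal sketch of the “lame gradients” baseline, assuming a fixed-step RK4 written with `jax.lax.scan` so that `jax.grad` differentiates straight through the solver’s own arithmetic. It is compared against `jax.experimental.ode.odeint` (JAX’s adaptive Runge-Kutta solver with an adjoint VJP) on the test problem dy/dt = -k y with y(0) = 1, where y(1) = exp(-k) and so the exact gradient with respect to k is -exp(-k). The helper names (`rk4_direct`, `final_state_direct`, `final_state_adjoint`) are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint  # adaptive RK solver with an adjoint VJP


def decay(y, t, k):
    # dy/dt = -k * y, using the (y, t, *args) signature that odeint expects.
    return -k * y


def rk4_direct(func, y0, t0, t1, num_steps, *args):
    """Fixed-step RK4 whose gradients come from differentiating its own steps."""
    dt = (t1 - t0) / num_steps

    def step(y, i):
        t = t0 + i * dt
        k1 = func(y, t, *args)
        k2 = func(y + 0.5 * dt * k1, t + 0.5 * dt, *args)
        k3 = func(y + 0.5 * dt * k2, t + 0.5 * dt, *args)
        k4 = func(y + dt * k3, t + dt, *args)
        return y + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4), None

    y1, _ = jax.lax.scan(step, y0, jnp.arange(num_steps))
    return y1


def final_state_direct(k):
    # Reverse mode runs back through every scan step (discretise-then-optimise).
    return rk4_direct(decay, jnp.array(1.0, dtype=jnp.float32), 0.0, 1.0, 100, k)


def final_state_adjoint(k):
    # Reverse mode solves the adjoint ODE backward in time (optimise-then-discretise).
    ts = jnp.array([0.0, 1.0])
    y0 = jnp.array(1.0, dtype=jnp.float32)
    return odeint(decay, y0, ts, k, rtol=1e-6, atol=1e-6)[-1]


k = jnp.array(0.5, dtype=jnp.float32)
g_direct = jax.grad(final_state_direct)(k)
g_adjoint = jax.grad(final_state_adjoint)(k)
g_exact = -jnp.exp(-k)  # d/dk of y(1) = exp(-k)
print(g_direct, g_adjoint, g_exact)
```

Differentiating through the solver keeps every intermediate stage for the backward pass, so its memory grows with the number of steps, whereas the adjoint approach re-solves the state backward in time instead of storing it.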
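One way the unified-solver idea could look, sketched under assumptions: every stepper shares a hypothetical signature `step(f, y, t, dt)`, and a hypothetical `make_odeint` bundles any forward stepper with any adjoint stepper into a fixed-step solver carrying a `jax.custom_vjp` rule. The backward pass integrates the standard augmented adjoint system (state `y`, adjoint `a`, parameter cotangent `g`) from `t1` back to `t0`. This is not how JAX’s built-in `odeint` is structured; that solver uses adaptive Dormand-Prince steps for both passes.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def _axpy(y, dt, k):
    # y + dt * k, leaf by leaf, so states may be tuples of arrays.
    return tree_map(lambda yi, ki: yi + dt * ki, y, k)


# Hypothetical unified stepper API: step(f, y, t, dt) returns y at t + dt.
def euler_step(f, y, t, dt):
    return _axpy(y, dt, f(y, t))


def rk4_step(f, y, t, dt):
    k1 = f(y, t)
    k2 = f(_axpy(y, dt / 2, k1), t + dt / 2)
    k3 = f(_axpy(y, dt / 2, k2), t + dt / 2)
    k4 = f(_axpy(y, dt, k3), t + dt)
    slope = tree_map(lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4)
    return _axpy(y, dt, slope)


def integrate(step, f, y0, t0, t1, num_steps):
    # Fixed-step loop; t1 < t0 gives a negative dt, i.e. backward integration.
    dt = (t1 - t0) / num_steps

    def body(y, i):
        return step(f, y, t0 + i * dt, dt), None

    y1, _ = jax.lax.scan(body, y0, jnp.arange(num_steps))
    return y1


def make_odeint(forward_step, adjoint_step, num_steps=100):
    """Bundle any two steppers into odeint(func, y0, params, t0, t1) -> y(t1)."""

    @partial(jax.custom_vjp, nondiff_argnums=(0, 3, 4))
    def odeint(func, y0, params, t0, t1):
        return integrate(forward_step, lambda y, t: func(y, t, params),
                         y0, t0, t1, num_steps)

    def odeint_fwd(func, y0, params, t0, t1):
        y1 = odeint(func, y0, params, t0, t1)
        return y1, (y1, params)

    def odeint_bwd(func, t0, t1, residuals, y1_bar):
        y1, params = residuals

        def augmented(state, t):
            # state = (y, a, g): y is re-solved backward, a = dL/dy,
            # g accumulates dL/dparams.
            y, a, _ = state
            dy, f_vjp = jax.vjp(lambda y_, p_: func(y_, t, p_), y, params)
            a_dfdy, a_dfdp = f_vjp(a)
            return dy, tree_map(jnp.negative, a_dfdy), tree_map(jnp.negative, a_dfdp)

        state1 = (y1, y1_bar, tree_map(jnp.zeros_like, params))
        _, y0_bar, params_bar = integrate(adjoint_step, augmented, state1,
                                          t1, t0, num_steps)
        return y0_bar, params_bar

    odeint.defvjp(odeint_fwd, odeint_bwd)
    return odeint


def decay(y, t, k):
    return -k * y


# RK4 on the forward pass, Euler on the adjoint pass (and an all-RK4 reference).
odeint_rk4_euler = make_odeint(rk4_step, euler_step)
odeint_rk4_rk4 = make_odeint(rk4_step, rk4_step)

y0 = jnp.array(1.0, dtype=jnp.float32)
k = jnp.array(0.5, dtype=jnp.float32)
g_mixed = jax.grad(lambda k: odeint_rk4_euler(decay, y0, k, 0.0, 1.0))(k)
g_rk4 = jax.grad(lambda k: odeint_rk4_rk4(decay, y0, k, 0.0, 1.0))(k)
print(g_mixed, g_rk4, -jnp.exp(-k))  # exact gradient is -exp(-k)
```

Because the steppers are ordinary Python functions passed to `make_odeint`, swapping in another stepper with the same signature needs no change to the VJP rule; the price of a cheap Euler adjoint is a first-order error in the gradient.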
For those watching this thread – check out Diffrax, which is a library doing pretty much everything discussed here. Other RK solvers, implicit solvers, discretise-then-optimise, etc.
btw, I have a rough implementation of some other RK solvers here, in case anyone has a use for them. 😃