Can't return solution of coupled differential equations
I’m trying to solve a mid-sized system of coupled differential equations with diffrax. I’m using version 0.2.0. Here’s a short snippet of dummy code that raises the issue I’m having:
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Kvaerno3, PIDController


def Results():
    def Y_prime(t, Y, args):
        dY = jnp.array([Y[6], (Y[5] - Y[6])**2, Y[0] + Y[7], (Y[1])**2,
                        Y[2], Y[3], Y[4]**3, Y[5]**2])
        return dY

    t_init = 100
    t_fin = 1e5
    Yn_i = 1e-5
    Yp_i = 1e-6
    Yd_i = 1e-12
    Yt_i = 1e-12
    YHe3_i = 1e-12
    Ya_i = 1e-12
    YLi7_i = 1e-12
    YBe7_i = 1e-12
    Y0 = jnp.array([[Yn_i], [Yp_i], [Yd_i], [Yt_i],
                    [YHe3_i], [Ya_i], [YLi7_i], [YBe7_i]])

    term = ODETerm(Y_prime)
    solver = Kvaerno3()
    stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
    t_eval = jnp.logspace(jnp.log10(t_init), jnp.log10(t_fin), num=100)
    sol_at_MT = diffeqsolve(
        term, solver,
        t0=jnp.float64(t_init), t1=jnp.float64(t_fin),
        dt0=jnp.float64((t_eval[1] - t_eval[0]) / 10),
        y0=Y0, stepsize_controller=stepsize_controller, max_steps=None,
    )

    Yn_MT_f, Yp_MT_f, Yd_MT_f, Yt_MT_f, YHe3_MT_f, Ya_MT_f, YLi7_MT_f, YBe7_MT_f = (
        sol_at_MT.ys[-1][0][0], sol_at_MT.ys[-1][1][0],
        sol_at_MT.ys[-1][2][0], sol_at_MT.ys[-1][3][0],
        sol_at_MT.ys[-1][4][0], sol_at_MT.ys[-1][5][0],
        sol_at_MT.ys[-1][6][0], sol_at_MT.ys[-1][7][0],
    )
    Yn_f, Yp_f, Yd_f, Yt_f, YHe3_f, Ya_f, YLi7_f, YBe7_f = (
        Yn_MT_f, Yp_MT_f, Yd_MT_f, Yt_MT_f, YHe3_MT_f, Ya_MT_f, YLi7_MT_f, YBe7_MT_f
    )
    return jnp.array([Yn_f, Yp_f, Yd_f, Yt_f, YHe3_f, Ya_f, YLi7_f, YBe7_f])


Yn_f, Yp_f, Yd_f, Yt_f, YHe3_f, Ya_f, YLi7_f, YBe7_f = Results()
print(Yn_f)
It seems diffrax successfully solves the differential equation, but struggles to return the output, i.e. it seems the code hangs when trying to assign values to the variable sol_at_MT. Tampering a bit with the diffrax source, it looks like there are two things going on.
One is that, no matter what I try to return (even if I set all of the returns to None), if the lines right before the return in integrate.py
branched_error_if(
    throw & jnp.invert(is_okay(result)),
    error_index,
    RESULTS.reverse_lookup,
)
aren’t commented out, the code will freeze. I can include a print statement right after these lines (just before the return) that prints out successfully even when they’re not commented, but I can’t assign anything to sol_at_MT without the code hanging if these lines are left in.
Then, if I comment that branched_error_if() call out, the code still hangs if I try to return ts, ys, stats or result from integrate.py. This doesn’t seem to be an issue of time or memory; the code just freezes up and can’t even be aborted from the command line whether I’m running locally or with extra resources on a cluster.
Top GitHub Comments
Sorry it’s taken so long for me to get back to you, just wanted to be sure this was also helping with the larger set of DEs I’m trying to solve. This is helpful, thanks!
It turns out that what is going on here is this:
- The dt0 that is being passed into diffrax.diffeqsolve is very large. When the solver makes its first numerical step, using this step size, we end up in a nonphysical region of space for which y[2] is negative.
- The vector field evaluates jnp.sqrt(y[2]). And jnp.sqrt(something negative) produces a NaN.
- The solver then ends up with a NaN, which it doesn’t know how to handle, and the whole thing chokes.

As a general rule, you shouldn’t use vector fields that are capable of returning NaNs: the NaN can propagate and break things in strange ways, as it has done here. (In this case, both the sqrt and the / are pretty suspicious.)

It’s actually pretty hard to come up with a sensible policy for handling NaNs in all the various cases they might arise. Still, at least for this case it’s clear that we want to reject the step; I’ll submit a PR soon so that diffrax.PIDController rejects steps in the event of a NaN.

For posterity, a couple of other solutions to this problem (right now, even without this upcoming PR):
- Set dt0=None or dt0=something small: in this case we never end up in an unphysical region of space, things work as normal, and both Diffrax and solve_ivp obtain identical solutions.
- Make the vector field return something other than NaN. In your case, out = jnp.stack(...); return jnp.where(jnp.isnan(out), jnp.inf, out) will suffice.
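As a rough illustration of the jnp.where workaround, here is a minimal sketch in plain JAX (no Diffrax call, so it only shows the vector-field side). The three-component right-hand side is a hypothetical stand-in, not the vector field from this thread; the only thing it shares with the real one is a jnp.sqrt(y[2]) term. The state y_unphysical is likewise an invented example of where one overly large first step might land.

```python
import jax.numpy as jnp


def vector_field(t, y, args):
    # Hypothetical right-hand side (an assumption for illustration only):
    # the sqrt(y[2]) term is what produces a NaN when y[2] < 0.
    return jnp.stack([-y[0], y[0] - y[1], jnp.sqrt(y[2]) - y[1]])


def guarded_vector_field(t, y, args):
    # Same right-hand side, but every NaN entry is replaced by inf,
    # following the jnp.where(jnp.isnan(out), jnp.inf, out) suggestion.
    out = vector_field(t, y, args)
    return jnp.where(jnp.isnan(out), jnp.inf, out)


# Invented "nonphysical" state with a negative third component.
y_unphysical = jnp.array([1e-5, 1e-6, -1e-3])

raw = vector_field(0.0, y_unphysical, None)
guarded = guarded_vector_field(0.0, y_unphysical, None)

print(bool(jnp.isnan(raw).any()))      # True: the unguarded field returns a NaN
print(bool(jnp.isnan(guarded).any()))  # False: the NaN has been replaced by inf
```

The guard only touches entries that are NaN, so the vector field is unchanged wherever y[2] is non-negative. In a real solve, guarded_vector_field would be the function wrapped in ODETerm.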