
Can't return solution of coupled differential equations

I’m trying to solve a mid-sized system of coupled differential equations with diffrax. I’m using version 0.2.0. Here’s a short snippet of dummy code that raises the issue I’m having:

import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Kvaerno3, PIDController

# Note: jnp.float64 only takes effect when 64-bit mode is enabled, e.g. with
# jax.config.update("jax_enable_x64", True); otherwise JAX warns and uses float32.

def Results():
    # Dummy vector field coupling the 8 components; each Y[i] has shape (1,).
    def Y_prime(t, Y, args):
        dY = jnp.array([Y[6], (Y[5] - Y[6])**2, Y[0] + Y[7], Y[1]**2,
                        Y[2], Y[3], Y[4]**3, Y[5]**2])
        return dY

    t_init = 100
    t_fin = 1e5

    # Initial values of the 8 components
    Yn_i = 1e-5
    Yp_i = 1e-6
    Yd_i = 1e-12
    Yt_i = 1e-12
    YHe3_i = 1e-12
    Ya_i = 1e-12
    YLi7_i = 1e-12
    YBe7_i = 1e-12

    # The state has shape (8, 1)
    Y0 = jnp.array([[Yn_i], [Yp_i], [Yd_i], [Yt_i], [YHe3_i], [Ya_i], [YLi7_i], [YBe7_i]])
    term = ODETerm(Y_prime)
    solver = Kvaerno3()
    stepsize_controller = PIDController(rtol=1e-8, atol=1e-8)
    t_eval = jnp.logspace(jnp.log10(t_init), jnp.log10(t_fin), num=100)
    sol_at_MT = diffeqsolve(
        term,
        solver,
        t0=jnp.float64(t_init),
        t1=jnp.float64(t_fin),
        dt0=jnp.float64((t_eval[1] - t_eval[0]) / 10),
        y0=Y0,
        stepsize_controller=stepsize_controller,
        max_steps=None,
    )
    # With the default SaveAt(t1=True), sol_at_MT.ys[-1] is the final state, shape (8, 1).
    Yn_f, Yp_f, Yd_f, Yt_f, YHe3_f, Ya_f, YLi7_f, YBe7_f = sol_at_MT.ys[-1][:, 0]
    return jnp.array([Yn_f, Yp_f, Yd_f, Yt_f, YHe3_f, Ya_f, YLi7_f, YBe7_f])

Yn_f, Yp_f, Yd_f, Yt_f, YHe3_f, Ya_f, YLi7_f, YBe7_f = Results()
print(Yn_f)

It seems diffrax successfully solves the differential equation but struggles to return the output, i.e. the code hangs when trying to assign the result to the variable sol_at_MT. Tinkering a bit with the diffrax source, it looks like there are two things going on.

One is that, no matter what I try to return (even if I set all of the returns to None), if the lines right before the return in integrate.py

branched_error_if(
    throw & jnp.invert(is_okay(result)),
    error_index,
    RESULTS.reverse_lookup,
)

aren’t commented out, the code will freeze. I can include a print statement right after these lines (just before the return) that prints successfully even when they’re not commented out, but I can’t assign anything to sol_at_MT without the code hanging if these lines are left in.

Then, if I comment that branched_error_if() call out, the code still hangs if I try to return ts, ys, stats or result from integrate.py. This doesn’t seem to be an issue of time or memory; the code just freezes up and can’t even be aborted from the command line whether I’m running locally or with extra resources on a cluster.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 12 (6 by maintainers)

Top GitHub Comments

cgiovanetti commented, Sep 1, 2022 (1 reaction)

Sorry it’s taken so long for me to get back to you; I just wanted to be sure this was also helping with the larger set of DEs I’m trying to solve. This is helpful, thanks!

patrick-kidger commented, Aug 26, 2022 (1 reaction)

It turns out that what is going on here is this:

  1. The initial step size dt0 that is being passed into diffrax.diffeqsolve is very large. When the solver makes its first numerical step, using this step size, we end up in a nonphysical region of space for which y[2] is negative.
  2. Too-large step sizes normally aren’t a problem, as the stepsize controller will just reject these. Except the vector field for this system is sort-of-questionably-defined: it includes a square root, in particular jnp.sqrt(y[2]). And jnp.sqrt(something negative) produces a NaN.
  3. Thus instead of seeing “this error is too large”, the stepsize controller just sees a NaN, which it doesn’t know how to handle, and the whole thing chokes.
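A toy illustration of the three steps above (not Diffrax’s internals): the vector field below is hypothetical, with a jnp.sqrt(y[2]) term like the one described. One embedded Euler/Heun step with a deliberately large dt drives y[2] negative, so the error estimate comes out as NaN. Both comparisons a naive controller might make are then False, and a naive step-size update turns into NaN as well.

```python
import jax.numpy as jnp

def vector_field(t, y, args):
    # Hypothetical field: y[2] decreases linearly, and dy[0]/dt = sqrt(y[2]),
    # which is only real-valued while y[2] >= 0.
    return jnp.array([jnp.sqrt(y[2]), -y[0], -1.0])

y = jnp.array([0.0, 0.0, 1.0])
dt = 10.0    # deliberately too large: one step takes y[2] from 1 to -9
tol = 1e-8

# Embedded Euler (order 1) / Heun (order 2) pair; their difference is the error estimate.
k1 = vector_field(0.0, y, None)
y_euler = y + dt * k1
k2 = vector_field(dt, y_euler, None)  # evaluates sqrt(-9), giving NaN
y_heun = y + 0.5 * dt * (k1 + k2)
err = jnp.max(jnp.abs(y_heun - y_euler)) / tol

print(err)          # nan
print(err <= 1.0)   # False: the step is not "accepted"
print(err > 1.0)    # False: but it is not "rejected" either

# A naive rescaling rule dt * safety * err**(-1/2) also becomes NaN.
new_dt = dt * 0.9 * err ** (-0.5)
print(new_dt)       # nan
```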

As a general rule, you shouldn’t use vector fields that are capable of returning NaNs: the NaN can propagate and break things in strange ways, as it has done here. (In this case, both the sqrt and the / are pretty suspicious.)
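Both operations flagged here are easy to check in isolation. The sketch below shows what jnp.sqrt and / produce outside their safe domain, plus a hypothetical helper, nonfinite_components (not part of JAX or Diffrax), that evaluates a vector field at a trial state and reports which outputs are NaN or infinite. The field is made up for illustration, and the helper uses jnp.nonzero without a size argument, so it is meant for eager debugging outside jit.

```python
import jax.numpy as jnp

x = jnp.array([-4.0, 0.0, 4.0])
roots = jnp.sqrt(x)   # values: nan, 0.0, 2.0 (sqrt of a negative is NaN)
ratios = x / 0.0      # values: -inf, nan, inf (0/0 is NaN, nonzero/0 is +-inf)

def nonfinite_components(vector_field, t, y, args=None):
    """Hypothetical debugging helper: indices where the field's output is NaN or +-inf."""
    out = vector_field(t, y, args)
    return jnp.nonzero(~jnp.isfinite(out))[0]

def field(t, y, args):
    # Hypothetical field using both suspicious operations.
    return jnp.stack([jnp.sqrt(y[2]), y[0] / y[1], -jnp.ones_like(y[2])])

bad = nonfinite_components(field, 0.0, jnp.array([1.0, 0.0, -9.0]))
print(bad)   # [0 1]
```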

It’s actually pretty hard to come up with a sensible policy for handling NaNs in all the various cases they might arise. Still, at least for this case it’s clear that we want to reject the step – I’ll submit a PR soon so that diffrax.PIDController rejects steps in the event of a NaN.
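A simplified sketch of what “reject the step in the event of a NaN” could look like. This is not PIDController’s actual code: accept_and_rescale, its integral-only (I) rescaling rule, and the parameter values are assumptions chosen for illustration. The key point is that a non-finite error estimate is treated as a rejection with the maximum allowed shrink, instead of being fed into the rescaling formula.

```python
import jax.numpy as jnp

def accept_and_rescale(err, dt, k=3.0, safety=0.9, factormin=0.2, factormax=10.0):
    """Hypothetical accept/reject rule for a scaled error estimate (err <= 1 is within tolerance)."""
    finite = jnp.isfinite(err)
    # A NaN or inf error estimate can never be accepted.
    accept = finite & (err <= 1.0)
    # Simple I-controller factor safety * err**(-1/k). Substitute a dummy value when err
    # is non-finite (and floor it above zero) so the formula itself never yields NaN/inf.
    safe_err = jnp.where(finite, jnp.maximum(err, 1e-10), 1.0)
    factor = jnp.clip(safety * safe_err ** (-1.0 / k), factormin, factormax)
    # On a non-finite error, shrink the step as much as allowed.
    factor = jnp.where(finite, factor, factormin)
    return accept, dt * factor

accept, new_dt = accept_and_rescale(jnp.nan, 10.0)
print(accept)   # False
```

Because the NaN is masked out before the power is taken, the returned dt is always finite, so a rejected step is retried with a smaller dt rather than a NaN one.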

For posterity, a couple of other solutions to this problem (right now, even without this upcoming PR):

  1. Pass either dt0=None or dt0=something small: in this case we never end up in an unphysical region of space, things work as normal, and both Diffrax and solve_ivp obtain identical solutions.
  2. Ensure that your vector field is never capable of returning a NaN. In your case, out = jnp.stack(...); return jnp.where(jnp.isnan(out), jnp.inf, out) will suffice.
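Solution 2 can be packaged as a small wrapper so it applies to any vector field. The helper name nan_to_inf and the example field are hypothetical; the jnp.where(jnp.isnan(out), jnp.inf, out) line is the one suggested above.

```python
import jax.numpy as jnp

def nan_to_inf(vector_field):
    """Hypothetical helper: wrap a vector field so any NaN in its output becomes +inf."""
    def wrapped(t, y, args):
        out = vector_field(t, y, args)
        return jnp.where(jnp.isnan(out), jnp.inf, out)
    return wrapped

def field(t, y, args):
    # Hypothetical field that returns NaN whenever y[2] < 0.
    return jnp.stack([jnp.sqrt(y[2]), -y[0], -jnp.ones_like(y[2])])

safe_field = nan_to_inf(field)
out = safe_field(0.0, jnp.array([10.0, 0.0, -9.0]), None)
print(out)   # [inf -10. -1.] (NaN from sqrt(-9) replaced by inf)
```

In the question’s setup this would mean term = ODETerm(nan_to_inf(Y_prime)) (or the real system’s vector field), optionally combined with dt0=None from solution 1.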