
AssertionError when taking grad of odeint with outer scope variable


Hello!

An AssertionError is raised when taking the grad of a function that calls odeint, if the ODE function uses a variable from an outer scope.

My setup is a 13-inch MacBook Pro (2019) running macOS Catalina 10.15.2. I compiled jax and jaxlib from source on the current master branch.

Reproduction code (x is the outer-scope variable):

from jax.experimental.ode import odeint
from jax import jit, grad, value_and_grad, vmap
import jax.numpy as np

@jit
def experiment(x):

    def model(y, t):
        dydt = -x * y
        return dydt

    history = odeint(model, 1., np.arange(0, 10, 0.1))
    return history[-1]

experiment = value_and_grad(experiment)
t = np.arange(0., 1., 0.01)
h, g = vmap(experiment)(t)

Running this gives the following output:

Traceback (most recent call last):
  File "/Users/kstorm/PycharmProjects/finger_model/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-a7e55fc52a04>", line 1, in <module>
    runfile('/Users/kstorm/PycharmProjects/finger_model/issue.py', wdir='/Users/kstorm/PycharmProjects/finger_model')
  File "/Users/kstorm/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/193.5233.109/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Users/kstorm/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/193.5233.109/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/kstorm/PycharmProjects/finger_model/issue.py", line 22, in <module>
    h, g = vmap(experiment)(t)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/api.py", line 759, in batched_fun
    lambda: _flatten_axes(out_tree(), out_axes))
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/interpreters/batching.py", line 34, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/api.py", line 428, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/api.py", line 1389, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/interpreters/ad.py", line 106, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/interpreters/ad.py", line 97, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError

I know the problem can be solved by passing x as an extra argument to odeint and adding a matching parameter to ‘model’, like this:

@jit
def experiment(x):

    def model(y, t, a):
        dydt = -a * y
        return dydt

    history = odeint(model, 1., np.arange(0, 10, 0.1), x)
    return history[-1]
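A self-contained version of the workaround above, as a sketch assuming a JAX release where `jax.experimental.ode.odeint(func, y0, t, *args)` forwards the extra positional arguments to `func`. It keeps the reporter's `vmap(value_and_grad(...))` structure (the vmapped array is renamed `xs` for clarity) and checks the result against the closed-form solution of dy/dt = -x * y with y(0) = 1, namely y(T) = exp(-x T), whose derivative with respect to x is -T exp(-x T). `T` is taken as `ts[-1]` so the check uses exactly the last point of the time grid.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

ts = jnp.arange(0., 10., 0.1)          # same time grid as the report

@jax.jit
def experiment(x):
    def model(y, t, a):
        # `a` is received as an explicit argument instead of being closed over.
        return -a * y

    y0 = jnp.ones_like(x)               # y(0) = 1, with the same dtype as x
    history = odeint(model, y0, ts, x)  # trailing args are forwarded to `model`
    return history[-1]

xs = jnp.arange(0., 1., 0.01)           # the report called this array `t`
h, g = jax.vmap(jax.value_and_grad(experiment))(xs)

# Closed form for dy/dt = -x * y, y(0) = 1:  y(T) = exp(-x T),  dy(T)/dx = -T exp(-x T).
T = ts[-1]
h_err = float(jnp.max(jnp.abs(h - jnp.exp(-xs * T))))
g_err = float(jnp.max(jnp.abs(g + T * jnp.exp(-xs * T))))
print("max value error:", h_err, "max gradient error:", g_err)
```

The errors are dominated by float32 round-off and the solver's adaptive step control, so the check uses a tolerance rather than exact equality.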

But I wonder why it doesn’t work with outer-scope variables.

Thanks in advance!

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Reactions: 3
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

NeilGirdhar commented, Apr 19, 2020 (1 reaction)

@mattjj You got it! I forgot to add my vjp. Thanks for the lightning fast response. I really appreciate it!

mattjj commented, Apr 19, 2020 (1 reaction)

@NeilGirdhar are you differentiating through a while_loop? See #2129. Otherwise you’ll have to give us more hints!
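For context on the while_loop question: JAX's reverse-mode differentiation (grad, vjp) does not support lax.while_loop, whereas lax.fori_loop with Python-int bounds is lowered to lax.scan and can be differentiated. The sketch below contrasts the two using hypothetical helpers `power_fori` and `power_while` that compute x**n; it illustrates the general limitation and is not a summary of #2129. Catching `ValueError` is an assumption about the exception type current JAX raises.

```python
import jax
import jax.numpy as jnp
from jax import lax

def power_fori(x, n=5):
    # Python-int bounds: fori_loop is lowered to lax.scan, so reverse mode works.
    return lax.fori_loop(0, n, lambda i, acc: acc * x, jnp.ones_like(x))

def power_while(x, n=5):
    # Same computation as a while_loop, which reverse mode cannot transpose.
    def cond(carry):
        i, _ = carry
        return i < n

    def body(carry):
        i, acc = carry
        return i + 1, acc * x

    _, acc = lax.while_loop(cond, body, (0, jnp.ones_like(x)))
    return acc

grad_fori = jax.grad(power_fori)(2.0)   # d/dx x**5 = 5 * x**4

while_error = None
try:
    jax.grad(power_while)(2.0)
except ValueError as err:               # assumed exception type
    while_error = err

print(grad_fori, type(while_error).__name__)
```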

@killianstorm To add on to what @froystig said: jax.experimental.ode.odeint is not actually a primitive (like the control flow primitives lax.cond, lax.scan, etc. which correctly handle closed-over tracers in their function-valued arguments). It’s just a function with a jax.custom_vjp rule defined. And higher-order functions with jax.custom_vjp rules can’t automatically handle closed-over tracers in their function-valued arguments.
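To illustrate the contrast with a control-flow primitive, here is a minimal sketch that integrates the same ODE with a fixed-step explicit Euler scheme built on lax.scan. The step function closes over x exactly like model in the original report, and vmap(value_and_grad(...)) still works. The Euler scheme and the constants DT and N_STEPS are illustrative choices, not part of odeint; because each step multiplies y by (1 - DT * x), the result has a closed form that makes the gradient easy to check.

```python
import jax
import jax.numpy as jnp

DT = 0.1        # illustrative fixed step size
N_STEPS = 99    # illustrative step count (t runs from 0 to 9.9)

def euler_final_state(x):
    """Explicit Euler for dy/dt = -x * y, y(0) = 1; `step` closes over x."""
    def step(y, _):
        # `x` is a closed-over tracer under grad/vmap; lax.scan handles it.
        return y + DT * (-x * y), None

    y_final, _ = jax.lax.scan(step, jnp.ones_like(x), xs=None, length=N_STEPS)
    return y_final

xs = jnp.arange(0., 1., 0.01)
h, g = jax.vmap(jax.value_and_grad(euler_final_state))(xs)

# Closed form of the Euler recursion and its derivative with respect to x.
expected_h = (1. - DT * xs) ** N_STEPS
expected_g = -N_STEPS * DT * (1. - DT * xs) ** (N_STEPS - 1)
```

Unlike odeint, this fixed-step scheme makes no accuracy guarantee; it only shows that a lax primitive accepts the closure.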

We should document this constraint on jax.custom_vjp functions better (hence the “documentation” tag on this issue), but it’s mentioned briefly in the last example (a fixed_point function) in the custom_vjp/jvp tutorial, specifically in the paragraph just before this heading.
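The constraint is visible in the shape of a custom_vjp higher-order function. In the sketch below, apply_fn is a hypothetical helper (not a JAX API) that takes the function f as a non-differentiable argument via nondiff_argnums, while the parameters travel as an explicit, differentiable argument. The backward rule returns cotangents only for the explicit arguments (y, params); roughly speaking, a value that f merely closed over has no slot in that tuple, which is why passing x explicitly to odeint works.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def apply_fn(f, y, params):
    """Hypothetical higher-order function: returns f(y, params) with a custom VJP."""
    return f(y, params)

def apply_fn_fwd(f, y, params):
    return f(y, params), (y, params)

def apply_fn_bwd(f, residuals, out_bar):
    y, params = residuals
    _, f_vjp = jax.vjp(f, y, params)
    # Cotangents are produced only for the explicit arguments (y, params).
    return f_vjp(out_bar)

apply_fn.defvjp(apply_fn_fwd, apply_fn_bwd)

def decay(y, a):
    return -a * y

def loss(a):
    # `a` reaches `decay` as an explicit argument, so the custom VJP covers it.
    return apply_fn(decay, 3.0, a)

grad_a = jax.grad(loss)(2.0)   # d/da (-a * 3) = -3
print(grad_a)
```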

Actually, the reason why odeint still lives in jax.experimental rather than being included directly in lax or something is precisely that we want to upgrade it to handle closed-over tracers automatically, like the other functions/primitives in lax which take function-valued arguments. We just haven’t done it yet! When we do, it’ll graduate out of jax.experimental.
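Later JAX releases provide jax.closure_convert, which hoists closed-over values that may carry derivatives (such as x under grad) out of a function and into explicit trailing arguments. The sketch below assumes a JAX release that provides it: model still closes over x as in the original report, is converted, and the hoisted constants are passed to odeint as extra arguments. Newer versions of odeint may already do this internally, so treat this as an illustration of the technique rather than a required step.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

ts = jnp.arange(0., 10., 0.1)

def final_state(x):
    def model(y, t):
        return -x * y                  # closes over `x`, as in the original report

    y0 = jnp.ones_like(x)
    # Hoist closed-over values that may carry tangents (here `x`) into explicit
    # trailing arguments, so that converted(y, t, *consts) computes model(y, t).
    converted, consts = jax.closure_convert(model, y0, ts[0])
    history = odeint(converted, y0, ts, *consts)
    return history[-1]

x0 = jnp.array(0.5, dtype=jnp.float32)
val, grad_x = jax.value_and_grad(final_state)(x0)

T = ts[-1]
print(val, jnp.exp(-x0 * T))           # value vs. closed form exp(-x T)
print(grad_x, -T * jnp.exp(-x0 * T))   # gradient vs. closed form -T exp(-x T)
```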
