AssertionError when taking grad of odeint with outer scope variable
Hello!
An AssertionError is raised when taking the gradient of odeint with respect to a variable from an outer scope, i.e. a variable that the ODE function closes over.
My setup is a 13-inch 2019 MacBook Pro running macOS Catalina 10.15.2. I compiled jax and jaxlib from source on the current master branch.
Reproduction code (x is the outer scope variable):
```python
from jax.experimental.ode import odeint
from jax import jit, grad, value_and_grad, vmap
import jax.numpy as np

@jit
def experiment(x):
    def model(y, t):
        dydt = -x * y
        return dydt
    history = odeint(model, 1., np.arange(0, 10, 0.1))
    return history[-1]

experiment = value_and_grad(experiment)
t = np.arange(0., 1., 0.01)
h, g = vmap(experiment)(t)
```
Running this gives the following output:
```
Traceback (most recent call last):
  File "/Users/kstorm/PycharmProjects/finger_model/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-a7e55fc52a04>", line 1, in <module>
    runfile('/Users/kstorm/PycharmProjects/finger_model/issue.py', wdir='/Users/kstorm/PycharmProjects/finger_model')
  File "/Users/kstorm/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/193.5233.109/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
  File "/Users/kstorm/Library/Application Support/JetBrains/Toolbox/apps/PyCharm-P/ch-0/193.5233.109/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/kstorm/PycharmProjects/finger_model/issue.py", line 22, in <module>
    h, g = vmap(experiment)(t)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/api.py", line 759, in batched_fun
    lambda: _flatten_axes(out_tree(), out_axes))
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/interpreters/batching.py", line 34, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/api.py", line 428, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/api.py", line 1389, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/interpreters/ad.py", line 106, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/kstorm/PycharmProjects/finger_model/jax/jax/interpreters/ad.py", line 97, in linearize
    assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
AssertionError
```
I know the problem can be worked around by passing `x` to `odeint` as an extra argument and adding a matching parameter to `model`, like this:
```python
# Drop-in replacement for `experiment` in the script above
@jit
def experiment(x):
    def model(y, t, a):
        dydt = -a * y
        return dydt
    history = odeint(model, 1., np.arange(0, 10, 0.1), x)
    return history[-1]
```
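For reference, here is a sketch of that workaround combined with the `value_and_grad` and `vmap` calls from the reproduction, assuming a JAX version whose `odeint` forwards extra positional arguments to the dynamics function (as the workaround relies on). Because dy/dt = -x·y with y(0) = 1, the final value is exp(-x·T), where T is the last time point, which gives a quick sanity check on `h` and `g`. The names `vg_experiment` and `xs` are illustrative only.

```python
from jax.experimental.ode import odeint
from jax import jit, value_and_grad, vmap
import jax.numpy as np

@jit
def experiment(x):
    # x reaches the dynamics as the explicit argument `a`, not via a closure
    def model(y, t, a):
        return -a * y
    history = odeint(model, 1., np.arange(0, 10, 0.1), x)
    return history[-1]

vg_experiment = value_and_grad(experiment)
xs = np.arange(0., 1., 0.01)
h, g = vmap(vg_experiment)(xs)  # values and d/dx for every x in xs
```

Analytically, h should be close to exp(-x·T) and g close to -T·exp(-x·T).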
But I’m wondering why it doesn’t work with variables from an outer scope (i.e. values that `model` closes over).
Thanks in advance!
Issue created 3 years ago.
@mattjj You got it! I forgot to add my vjp. Thanks for the lightning-fast response. I really appreciate it!
@NeilGirdhar are you differentiating through a `while_loop`? See #2129. Otherwise you’ll have to give us more hints!

@killianstorm To add on to what @froystig said:
`jax.experimental.ode.odeint` is not actually a primitive (like the control flow primitives `lax.cond`, `lax.scan`, etc., which correctly handle closed-over tracers in their function-valued arguments). It’s just a function with a `jax.custom_vjp` rule defined. And higher-order functions with `jax.custom_vjp` rules can’t automatically handle closed-over tracers in their function-valued arguments.

We should document this constraint on `jax.custom_vjp` functions better (hence the “documentation” tag on this issue), but it’s mentioned briefly in the last example (a `fixed_point` function) in the custom_vjp/jvp tutorial, specifically in the paragraph just before this heading.

Actually, the reason why `odeint` still lives in `jax.experimental` rather than being included directly in `lax` or something is precisely that we want to upgrade it to handle closed-over tracers automatically, like the other functions/primitives in `lax` which take function-valued arguments. We just haven’t done it yet! When we do, it’ll graduate out of `jax.experimental`.
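The constraint on higher-order `jax.custom_vjp` functions can be illustrated with a toy example. This is a sketch, not JAX’s `odeint`: `euler_odeint` is a hypothetical fixed-step forward-Euler integrator whose VJP rule simply re-runs the solver under `jax.vjp`, and `odeint_closure` is a hypothetical wrapper. The rule can only return cotangents for explicit arguments, so anything to be differentiated must travel through `*args`. `jax.closure_convert` (a JAX utility meant for this situation) hoists differentiable values out of a closure so they can be passed that way. It assumes a JAX version recent enough to provide `jax.closure_convert`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def _euler(f, y0, ts, *args):
    """Fixed-step forward Euler; returns the state at every time in ts."""
    def step(y, t_pair):
        t0, t1 = t_pair
        y_next = y + (t1 - t0) * f(y, t0, *args)
        return y_next, y_next
    _, ys = lax.scan(step, y0, (ts[:-1], ts[1:]))
    return jnp.concatenate([y0[None], ys])


# Hypothetical stand-in for odeint: `f` is a nondiff argument, so gradients
# can only reach values that are passed explicitly in *args.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def euler_odeint(f, y0, ts, *args):
    return _euler(f, y0, ts, *args)


def _euler_fwd(f, y0, ts, *args):
    # Save the inputs; the backward pass recomputes the forward solve.
    return _euler(f, y0, ts, *args), (y0, ts, args)


def _euler_bwd(f, res, ys_bar):
    y0, ts, args = res
    _, vjp_fn = jax.vjp(lambda y0, ts, *args: _euler(f, y0, ts, *args),
                        y0, ts, *args)
    # One cotangent per differentiable input: y0, ts, then each of *args.
    return vjp_fn(ys_bar)


euler_odeint.defvjp(_euler_fwd, _euler_bwd)


def odeint_closure(f, y0, ts):
    y0 = jnp.asarray(y0, dtype=ts.dtype)
    # Turn differentiable values that `f` closes over into explicit arguments.
    f_converted, consts = jax.closure_convert(f, y0, ts[0])
    return euler_odeint(f_converted, y0, ts, *consts)


ts = jnp.linspace(0.0, 1.0, 101)  # 100 Euler steps of size 0.01

def experiment(x):
    model = lambda y, t: -x * y  # closes over x, as in the original report
    return odeint_closure(model, 1.0, ts)[-1]

val, grad_x = jax.value_and_grad(experiment)(0.5)
```

Here the discrete solution is (1 - 0.01·x)^100, so the gradient should be close to -(1 - 0.01·x)^99. Passing the closure directly to `euler_odeint` would leave `x` invisible to `_euler_bwd`, which is the situation described in this thread; `odeint_closure` (like the explicit-argument workaround) gives the rule a cotangent slot for it.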