stax.serial.apply_fun is not a valid JAX type inside odeint
Hi, FWIW, I'm using a self-built jax and jaxlib, following the instructions from #2083:
```
# Name     Version   Build   Channel
jax        0.1.64    <pip>
jaxlib     0.1.45    <pip>
```
I'm trying to get gradients through an ODE solver. First, I ran into an AssertionError (issue #2718), which I think I solved by passing all the arguments directly into `odeint`. Then, following issue #2531, I worked around another AssertionError by taking the `vmap` of `grad`s instead of the `grad` of a `vmap`.
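As a rough sketch of that pattern (with a hypothetical stand-in loss, not the actual one from my code):

```python
import jax
import jax.numpy as np

def example_loss(y0, y_true):
    # Hypothetical stand-in for the real per-example scalar loss.
    return np.mean(np.abs(y0 - y_true))

# vmap of grad: per-example gradients, instead of grad of a vmapped (batched) loss.
batch_grad = jax.vmap(jax.grad(example_loss), in_axes=(0, 0))
per_example_grads = batch_grad(np.ones((16, 2)), np.zeros((16, 2)))  # shape (16, 2)
```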
Now I'm getting the following error.
Full traceback:

```
----> 1 batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], [U1,U2], [U1_params,U2_params])
~/Code/jax/jax/api.py in batched_fun(*args)
805 _check_axis_sizes(in_tree, args_flat, in_axes_flat)
806 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 807 lambda: _flatten_axes(out_tree(), out_axes))
808 return tree_unflatten(out_tree(), out_flat)
809
~/Code/jax/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
32 # executes a batched version of `fun` following out_dim_dests
33 batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34 return batched_fun.call_wrapped(*in_vals)
35
36 @lu.transformation_with_aux
~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
~/Code/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
436 f_partial, dyn_args = argnums_partial(f, argnums, args)
437 if not has_aux:
--> 438 ans, vjp_py = _vjp(f_partial, *dyn_args)
439 else:
440 ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
~/Code/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
1437 if not has_aux:
1438 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1439 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1440 out_tree = out_tree()
1441 else:
~/Code/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
104 def vjp(traceable, primals, has_aux=False):
105 if not has_aux:
--> 106 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
107 else:
108 out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
~/Code/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
93 _, in_tree = tree_flatten(((primals, primals), {}))
94 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
96 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
97 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
435 with new_master(trace_type, bottom=bottom) as master:
436 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
438 assert not env
439 del master
~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
152 flat_fun, out_tree = flatten_fun(f, in_tree)
153 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 154 name=flat_fun.__name__)
155 return tree_unflatten(out_tree(), out)
156
~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
1003 tracers = map(top_trace.full_raise, args)
1004 process = getattr(top_trace, processor)
-> 1005 outs = map(full_lower, process(primitive, f, tracers, params))
1006 return apply_todos(env_trace_todo(), outs)
1007
~/Code/jax/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
342 name = params.get('name', f.__name__)
343 params = dict(params, name=wrap_name(name, 'jvp'))
--> 344 result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
345 primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
346 return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]
~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
1003 tracers = map(top_trace.full_raise, args)
1004 process = getattr(top_trace, processor)
-> 1005 outs = map(full_lower, process(primitive, f, tracers, params))
1006 return apply_todos(env_trace_todo(), outs)
1007
~/Code/jax/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
175 in_pvs, in_consts = unzip2([t.pval for t in tracers])
176 fun, aux = partial_eval(f, self, in_pvs)
--> 177 out_flat = call_primitive.bind(fun, *in_consts, **params)
178 out_pvs, jaxpr, env = aux()
179 env_tracers = map(self.full_raise, env)
~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
1003 tracers = map(top_trace.full_raise, args)
1004 process = getattr(top_trace, processor)
-> 1005 outs = map(full_lower, process(primitive, f, tracers, params))
1006 return apply_todos(env_trace_todo(), outs)
1007
~/Code/jax/jax/interpreters/batching.py in process_call(self, call_primitive, f, tracers, params)
146 else:
147 f, dims_out = batch_subtrace(f, self.master, dims)
--> 148 vals_out = call_primitive.bind(f, *vals, **params)
149 return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
150
~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
999 if top_trace is None:
1000 with new_sublevel():
-> 1001 outs = primitive.impl(f, *args, **params)
1002 else:
1003 tracers = map(top_trace.full_raise, args)
~/Code/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
460
461 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
--> 462 compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
463 try:
464 return compiled_fun(*args)
~/Code/jax/jax/linear_util.py in memoized_fun(fun, *args)
219 fun.populate_stores(stores)
220 else:
--> 221 ans = call(fun, *args)
222 cache[key] = (ans, fun.stores)
223 return ans
~/Code/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
477 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
478 jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 479 fun, pvals, instantiate=False, stage_out=True, bottom=True)
480
481 _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))
~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
435 with new_master(trace_type, bottom=bottom) as master:
436 fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
438 assert not env
439 del master
~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
<ipython-input-17-de50dc731d85> in loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams)
1 @partial(jit, static_argnums=(4,))
2 def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
----> 3 pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams)
4 loss = np.mean(np.abs(pred_y-batch_y))
5 return loss
~/Code/jax/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
152 shape/structure as `y0` except with a new leading axis of length `len(t)`.
153 """
--> 154 return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
155
156 @partial(jax.jit, static_argnums=(0, 1, 2, 3))
~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
149 dyn_args = args
150 args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 151 _check_args(args_flat)
152 flat_fun, out_tree = flatten_fun(f, in_tree)
153 out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
~/Code/jax/jax/api.py in _check_args(args)
1558 if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
1559 raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1560 .format(arg, type(arg)))
1561
1562 def _valid_jaxtype(arg):
TypeError: Argument '<function serial.<locals>.apply_fun at 0x2b06c3d6f7a0>' of type <class 'function'> is not a valid JAX type
```
I'm passing two `stax.serial` modules, each with three `Dense` layers, as inputs to `odeint` in order to integrate the Lotka-Volterra ODEs. `ufuncs` and `uparams` contain the apply functions and params of the `stax.serial` modules:
```python
import jax.numpy as np  # np here is jax.numpy

def lv_UDE(y, t, params, ufuncs, uparams):
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return np.array([dRdt, dFdt])
```
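The `ufuncs` and `uparams` are built along these lines (a sketch; the layer widths, activations, and RNG seeds here are hypothetical stand-ins, since the actual network definitions aren't shown):

```python
from jax import random
from jax.experimental import stax

# Hypothetical networks with three Dense layers each, standing in for U1 and U2.
U1_init, U1 = stax.serial(stax.Dense(32), stax.Tanh, stax.Dense(32), stax.Tanh, stax.Dense(1))
U2_init, U2 = stax.serial(stax.Dense(32), stax.Tanh, stax.Dense(32), stax.Tanh, stax.Dense(1))
_, U1_params = U1_init(random.PRNGKey(0), (2,))  # the state y = [R, F] has two components
_, U2_params = U2_init(random.PRNGKey(1), (2,))

ufuncs = [U1, U2]                  # plain Python functions -- not valid JAX types
uparams = [U1_params, U2_params]   # pytrees of arrays -- valid JAX types
```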
I'm trying to get gradients through `odeint` w.r.t. `uparams`. Is there a workaround to pass `stax.serial` modules as arguments? Thanks in advance.
Hi @mattjj, I tried your solution and it works seamlessly with `vmap`. Thanks again.

Hey @skrsna, thanks for the question!

In your example, it seems `lv_UDE` is never called. Is that intentional?

The underlying issue here is that `odeint` can't take function-valued arguments in `*args`; those must be arrays (or potentially-nested containers of arrays, like potentially-nested lists/tuples/dicts of arrays). Instead of passing `ufuncs` via the `*args` of `odeint`, maybe you can instead just write something like the sketch below. WDYT?