
stax.serial.apply_fun is not a valid JAX type inside odeint

See original GitHub issue

Hi, FWIW, I’m using self-built jax and jaxlib, following the instructions from #2083.

#
# Name                    Version                   Build  Channel
jax                       0.1.64                    <pip>
jaxlib                    0.1.45                    <pip>

I’m trying to get gradients through an ODE solver. First, I ran into the AssertionError from issue #2718, which I think I solved by passing all the arguments directly into odeint. Then I followed the instructions for another AssertionError, issue #2531, and switched to taking vmap of grads instead of grads of vmap (a minimal sketch of that pattern follows). Now I’m getting the following error.
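
A self-contained toy illustration of that "vmap of grad" pattern, for reference; the loss and names here are illustrative, not from this issue:

import jax.numpy as jnp
from jax import grad, vmap

# Toy per-example loss; the real loss in this issue goes through odeint.
def per_example_loss(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# grad of the per-example loss, then vmap over the batch axis of x and y;
# params is shared across the batch, so its in_axes entry is None.
per_example_grad = grad(per_example_loss)                 # gradient w.r.t. params
batch_grad = vmap(per_example_grad, in_axes=(None, 0, 0))

params = jnp.ones(3)
xs = jnp.ones((8, 3))
ys = jnp.zeros(8)
grads = batch_grad(params, xs, ys)   # shape (8, 3): one gradient per example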

Full traceback:

----> 1 batch_grad(batch_y0, batch_t, batch_y,[1.3,1.8], [U1,U2], [U1_params,U2_params])

~/Code/jax/jax/api.py in batched_fun(*args)
    805     _check_axis_sizes(in_tree, args_flat, in_axes_flat)
    806     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 807                               lambda: _flatten_axes(out_tree(), out_axes))
    808     return tree_unflatten(out_tree(), out_flat)
    809 

~/Code/jax/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35 
     36 @lu.transformation_with_aux

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in value_and_grad_f(*args, **kwargs)
    436     f_partial, dyn_args = argnums_partial(f, argnums, args)
    437     if not has_aux:
--> 438       ans, vjp_py = _vjp(f_partial, *dyn_args)
    439     else:
    440       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/Code/jax/jax/api.py in _vjp(fun, *primals, **kwargs)
   1437   if not has_aux:
   1438     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1439     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1440     out_tree = out_tree()
   1441   else:

~/Code/jax/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    104 def vjp(traceable, primals, has_aux=False):
    105   if not has_aux:
--> 106     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    107   else:
    108     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/Code/jax/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     93   _, in_tree = tree_flatten(((primals, primals), {}))
     94   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 95   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     96   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
     97   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,
--> 154                        name=flat_fun.__name__)
    155     return tree_unflatten(out_tree(), out)
    156 

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    342     name = params.get('name', f.__name__)
    343     params = dict(params, name=wrap_name(name, 'jvp'))
--> 344     result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **params)
    345     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    346     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    175     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    176     fun, aux = partial_eval(f, self, in_pvs)
--> 177     out_flat = call_primitive.bind(fun, *in_consts, **params)
    178     out_pvs, jaxpr, env = aux()
    179     env_tracers = map(self.full_raise, env)

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
   1003     tracers = map(top_trace.full_raise, args)
   1004     process = getattr(top_trace, processor)
-> 1005     outs = map(full_lower, process(primitive, f, tracers, params))
   1006   return apply_todos(env_trace_todo(), outs)
   1007 

~/Code/jax/jax/interpreters/batching.py in process_call(self, call_primitive, f, tracers, params)
    146     else:
    147       f, dims_out = batch_subtrace(f, self.master, dims)
--> 148       vals_out = call_primitive.bind(f, *vals, **params)
    149       return [BatchTracer(self, v, d) for v, d in zip(vals_out, dims_out())]
    150 

~/Code/jax/jax/core.py in _call_bind(processor, post_processor, primitive, f, *args, **params)
    999   if top_trace is None:
   1000     with new_sublevel():
-> 1001       outs = primitive.impl(f, *args, **params)
   1002   else:
   1003     tracers = map(top_trace.full_raise, args)

~/Code/jax/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, *args)
    460 
    461 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name):
--> 462   compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
    463   try:
    464     return compiled_fun(*args)

~/Code/jax/jax/linear_util.py in memoized_fun(fun, *args)
    219       fun.populate_stores(stores)
    220     else:
--> 221       ans = call(fun, *args)
    222       cache[key] = (ans, fun.stores)
    223     return ans

~/Code/jax/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, *arg_specs)
    477   pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
    478   jaxpr, pvals, consts = pe.trace_to_jaxpr(
--> 479       fun, pvals, instantiate=False, stage_out=True, bottom=True)
    480 
    481   _map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

~/Code/jax/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate, stage_out, bottom, trace_type)
    435   with new_master(trace_type, bottom=bottom) as master:
    436     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 437     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    438     assert not env
    439     del master

~/Code/jax/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

<ipython-input-17-de50dc731d85> in loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams)
      1 @partial(jit, static_argnums=(4,))
      2 def loss(batch_y0, batch_t, batch_y, params, ufuncs,uparams):
----> 3     pred_y = odeint(batch_y0,batch_t,params,ufuncs,uparams)
      4     loss = np.mean(np.abs(pred_y-batch_y))
      5     return loss

~/Code/jax/jax/experimental/ode.py in odeint(func, y0, t, rtol, atol, mxstep, *args)
    152     shape/structure as `y0` except with a new leading axis of length `len(t)`.
    153   """
--> 154   return _odeint_wrapper(func, rtol, atol, mxstep, y0, t, *args)
    155 
    156 @partial(jax.jit, static_argnums=(0, 1, 2, 3))

~/Code/jax/jax/api.py in f_jitted(*args, **kwargs)
    149       dyn_args = args
    150     args_flat, in_tree = tree_flatten((dyn_args, kwargs))
--> 151     _check_args(args_flat)
    152     flat_fun, out_tree = flatten_fun(f, in_tree)
    153     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend,

~/Code/jax/jax/api.py in _check_args(args)
   1558     if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
   1559       raise TypeError("Argument '{}' of type {} is not a valid JAX type"
-> 1560                       .format(arg, type(arg)))
   1561 
   1562 def _valid_jaxtype(arg):

TypeError: Argument '<function serial.<locals>.apply_fun at 0x2b06c3d6f7a0>' of type <class 'function'> is not a valid JAX type

I’m passing two stax.serial modules with three Dense layers each as inputs to odeint, to integrate the Lotka-Volterra ODEs. ufuncs and uparams contain the apply functions and params of the stax.serial modules.

def lv_UDE(y, t, params, ufuncs, uparams):
    # Lotka-Volterra dynamics with the two networks U1, U2 replacing the
    # interaction terms.
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return np.array([dRdt, dFdt])
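
For context, the two networks above are presumably built along these lines (a hypothetical sketch: the actual layer widths aren't given in the issue, and it uses the jax.experimental.stax API of that era):

from jax import random
from jax.experimental import stax

def make_net(rng, in_dim=2, hidden=32):
    # Three Dense layers, as described above; the widths are made up.
    init_fun, apply_fun = stax.serial(
        stax.Dense(hidden), stax.Tanh,
        stax.Dense(hidden), stax.Tanh,
        stax.Dense(1),
    )
    _, params = init_fun(rng, (-1, in_dim))
    return apply_fun, params

key1, key2 = random.split(random.PRNGKey(0))
U1, U1_params = make_net(key1)   # apply_fun and params of the first network
U2, U2_params = make_net(key2)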

I’m trying to get gradients through odeint w.r.t. uparams. Is there a workaround for passing stax.serial modules as an argument? Thanks in advance.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 9 (5 by maintainers)

Top GitHub Comments

1 reaction
skrsna commented, May 2, 2020

Hi @mattjj, I tried your solution and it works seamlessly with vmap. Thanks again.

1 reaction
mattjj commented, May 2, 2020

Hey @skrsna, thanks for the question!

In your example, it seems lv_UDE is never called. Is that intentional?

The underlying issue here is that odeint can’t take function-valued arguments in *args; those must be arrays, or potentially nested containers of arrays (lists/tuples/dicts of arrays). Instead of passing ufuncs via odeint’s *args, maybe you can write something like:

def lv_UDE(ufuncs, y, t, params, uparams):  # moved ufuncs to front
    ...

odeint(partial(lv_UDE, ufuncs), ...)

WDYT?
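
For concreteness, applying that suggestion to the loss from the traceback might look roughly like this (a sketch, not a verified fix: the reordered lv_UDE signature follows the comment above, and keeping ufuncs static for jit assumes it is passed as a hashable tuple such as (U1, U2)):

from functools import partial
import jax.numpy as jnp
from jax import jit
from jax.experimental.ode import odeint

def lv_UDE(ufuncs, y, t, params, uparams):
    # Same dynamics as before, with ufuncs moved to the front so it can be
    # closed over with partial instead of travelling through odeint's *args.
    R, F = y
    alpha, theta = params
    U1, U2 = ufuncs
    U1_params, U2_params = uparams
    dRdt = alpha * R - U1(U1_params, y)
    dFdt = -theta * F + U2(U2_params, y)
    return jnp.array([dRdt, dFdt])

@partial(jit, static_argnums=(4,))   # ufuncs stays static; uparams stays traceable
def loss(batch_y0, batch_t, batch_y, params, ufuncs, uparams):
    # Only arrays (and nested containers of arrays) reach odeint's *args,
    # so gradients w.r.t. uparams can flow through the solver.
    pred_y = odeint(partial(lv_UDE, ufuncs), batch_y0, batch_t, params, uparams)
    return jnp.mean(jnp.abs(pred_y - batch_y))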
