vmap using in_axes doesn't handle named arguments
Passing keyword arguments to a function vectorized with `jax.vmap(in_axes=..., out_axes=...)` does not seem to work and results in an `AssertionError` with an empty message.
For example:
```python
import jax
import jax.numpy as jnp

def f(a, b, c):
    return (2*a, 3*b + c)

print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(jnp.array([1, 2]), jnp.array([2, 4]), 0.5))  # works
# (DeviceArray([2, 4], dtype=int32), DeviceArray([ 6.5, 12.5], dtype=float32))

print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5))  # doesn't work
```
The second call fails with:
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[... skipping hidden 1 frame]
~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/_src/tree_util.py in tree_map(f, tree, is_leaf, *rest)
166 leaves, treedef = tree_flatten(tree, is_leaf)
--> 167 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
168 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/_src/tree_util.py in <listcomp>(.0)
166 leaves, treedef = tree_flatten(tree, is_leaf)
--> 167 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
168 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Tuple arity mismatch: 0 != 3; tuple: ().
During handling of the above exception, another exception occurred:
AssertionError Traceback (most recent call last)
/tmp/ipykernel_218392/621472192.py in <module>
8 # (DeviceArray([2, 4], dtype=int32), DeviceArray([ 6.5, 12.5], dtype=float32))
9
---> 10 print(jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5)) # doesn't work
[... skipping hidden 2 frame]
~/miniconda3/envs/lcms/lib/python3.9/site-packages/jax/api_util.py in flatten_axes(name, treedef, axis_tree, kws)
274 # message only to be about the positional arguments
275 treedef, leaf = treedef_children(treedef)
--> 276 assert treedef_is_leaf(leaf)
277 axis_tree, _ = axis_tree
278 raise ValueError(f"{name} specification must be a tree prefix of the "
AssertionError:
```
Version used: jax: 0.2.18
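For context, the `jax.vmap` documentation says that `in_axes` must be a tree prefix of the *positional* argument tuple, and that arguments passed as keywords are always mapped over their leading axis (axis 0). In the failing call every argument is passed by name, so the positional tuple is empty. That matches the inner `Tuple arity mismatch: 0 != 3; tuple: ()` in the traceback. Below is a minimal sketch of two workarounds. It assumes a recent JAX release and has not been checked against 0.2.18. The first calls the vmapped function positionally. The second binds the unmapped argument with `functools.partial` and maps only the remaining positional arguments.

```python
import functools

import jax
import jax.numpy as jnp

def f(a, b, c):
    return (2 * a, 3 * b + c)

a = jnp.array([1, 2])
b = jnp.array([2, 4])

# Workaround 1: in_axes lines up with the positional arguments,
# so pass every argument positionally.
out_positional = jax.vmap(f, in_axes=(0, 0, None), out_axes=0)(a, b, 0.5)

# Workaround 2: close over the unmapped argument by name with functools.partial.
# vmap only sees the two remaining positional arguments, both mapped over axis 0.
f_c = functools.partial(f, c=0.5)
out_partial = jax.vmap(f_c, in_axes=(0, 0), out_axes=0)(a, b)

print(out_positional)
print(out_partial)
```

The `partial` variant is convenient when the unmapped argument is a fixed configuration value. The mapped arrays still have to be passed positionally, because a keyword argument cannot opt out of mapping via `in_axes`.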
I was hit by this today. Could we perhaps get an explicit error message until the behavior is fixed?
Got the same problem today. The empty error message makes this very hard to diagnose. By chance I had a snippet that did not use named arguments, so I could compare it with my code; otherwise it would have been very difficult to track down.
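Until clearer error reporting or keyword support is available, code that prefers calling by name can bind the keywords to positions before calling the vmapped function. The sketch below uses a hypothetical helper, `vmap_kw`, which is not part of JAX. It relies on `inspect.signature(...).bind`, so it assumes the wrapped function has an inspectable Python signature. Parameters filled from defaults also become positional, so a tuple `in_axes` needs an entry for each of them.

```python
import inspect

import jax
import jax.numpy as jnp

def vmap_kw(fun, in_axes=0, out_axes=0):
    """Hypothetical helper (not a JAX API): like jax.vmap, but keyword
    arguments are first bound to their positions so a tuple in_axes applies."""
    sig = inspect.signature(fun)
    mapped = jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # bound.args contains every positional-or-keyword parameter in order,
        # including the ones supplied by name; only keyword-only parameters
        # remain in bound.kwargs (and vmap maps those over axis 0).
        return mapped(*bound.args, **bound.kwargs)

    return wrapper

def f(a, b, c):
    return (2 * a, 3 * b + c)

out = vmap_kw(f, in_axes=(0, 0, None))(a=jnp.array([1, 2]), b=jnp.array([2, 4]), c=0.5)
print(out)
```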