vmap is incompatible with custom pytree
The following repro code

```python
from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special(object):
    def __init__(self, x, y):
        shape = lax.broadcast_shapes(np.shape(x), np.shape(y))
        self.x = np.broadcast_to(x, shape)
        self.y = np.broadcast_to(y, shape)

def special_flatten(v):
    return ((v.x, v.y), None)

def special_unflatten(aux_data, children):
    return Special(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return Special(x, x)

assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert pmap(f)(np.ones(3)).x.shape == (3,)  # pmap needs at least 3 devices here
vmap(f)(np.ones(3))  # fail!
```
triggers the error
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-36-e7c65535399a> in <module>
18 assert jit(f)(np.ones(3)).x.shape == (2, 2, 3)
19 assert lax.map(f, np.ones(3)).x.shape == (2, 3, 2)
---> 20 vmap(f)(np.ones(3))
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in batched_fun(*args)
768 _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
769 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770 lambda: flatten_axes(out_tree(), out_axes))
771 return tree_unflatten(out_tree(), out_flat)
772
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
32 # executes a batched version of `fun` following out_dim_dests
33 batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34 return batched_fun.call_wrapped(*in_vals)
35
36 @lu.transformation_with_aux
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
152 while stack:
153 gen, out_store = stack.pop()
--> 154 ans = gen.send(ans)
155 if out_store is not None:
156 ans, side = ans
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in _batch_fun(sum_match, in_dims, out_dims_thunk, out_dim_dests, *in_vals, **params)
57 out_vals = yield (master, in_dims,) + in_vals, params
58 del master
---> 59 out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
60 out_dims = out_dims_thunk()
61 for od, od_dest in zip(out_dims, out_dim_dests):
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in <lambda>()
768 _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
769 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770 lambda: flatten_axes(out_tree(), out_axes))
771 return tree_unflatten(out_tree(), out_flat)
772
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api_util.py in flatten_axes(treedef, axis_tree)
106 # TODO(mattjj,phawkins): improve this implementation
107 proxy = object()
--> 108 dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
109 axes = []
110 add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/tree_util.py in tree_unflatten(treedef, leaves)
68 structure described by `treedef`.
69 """
---> 70 return treedef.unflatten(leaves)
71
72 def tree_leaves(tree):
<ipython-input-36-e7c65535399a> in special_unflatten(aux_data, children)
9
10 def special_unflatten(aux_data, children):
---> 11 return Special(*children)
12
13 tree_util.register_pytree_node(Special, special_flatten, special_unflatten)
<ipython-input-36-e7c65535399a> in __init__(self, x)
3 class Special(object):
4 def __init__(self, x):
----> 5 self.x = np.broadcast_to(x, (2,) + np.shape(x))
6
7 def special_flatten(v):
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in broadcast_to(arr, shape)
1280 def broadcast_to(arr, shape):
1281 """Like Numpy's broadcast_to but doesn't necessarily return views."""
-> 1282 arr = arr if isinstance(arr, ndarray) else array(arr)
1283 shape = canonicalize_shape(shape) # check that shape is concrete
1284 arr_shape = _shape(arr)
~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
2070 return array(onp.asarray(view), dtype, copy)
2071
-> 2072 raise TypeError("Unexpected input type for array: {}".format(type(object)))
2073
2074 if ndmin > ndim(out):
TypeError: Unexpected input type for array: <class 'object'>
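The issue here is that we are applying `np.broadcast_to` to an `object` x (instead of an ndarray) when the tree is unflattened under `vmap`: as the traceback shows, `flatten_axes` unflattens the output tree with `object()` placeholders as leaves, which re-runs `Special.__init__`. Is there any workaround available in JAX?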
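A minimal sketch of the failure mechanism, assuming a recent JAX where `jnp.broadcast_shapes` is available: calling `tree_unflatten` by hand with `object()` placeholders, as `flatten_axes` does in the traceback, reproduces the error without `vmap`. The exact exception type and message may differ between JAX versions (the traceback above shows a `TypeError`).

```python
import jax.numpy as jnp
from jax import tree_util

class Special:
    def __init__(self, x, y):
        # Array work inside __init__ is what breaks on non-array leaves
        shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
        self.x = jnp.broadcast_to(x, shape)
        self.y = jnp.broadcast_to(y, shape)

tree_util.register_pytree_node(
    Special,
    lambda v: ((v.x, v.y), None),              # flatten: two array leaves, no aux data
    lambda aux, children: Special(*children),  # unflatten: re-runs __init__
)

treedef = tree_util.tree_structure(Special(jnp.ones(3), jnp.ones(3)))

# Mimic the placeholder unflatten seen in the traceback (api_util.flatten_axes)
placeholders = [object()] * treedef.num_leaves
try:
    tree_util.tree_unflatten(treedef, placeholders)
    failed = False
except Exception as e:  # a TypeError in the traceback above
    failed = True
    print("unflatten failed:", type(e).__name__, e)
```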
Issue analytics: created 3 years ago; 6 comments (6 by maintainers).
I think this issue is resolved, but if not, please let us know, @fehiepsi! By the way, it's great to hear from you again 😄
Like I said, you can use some Python magic for an alternative constructor that skips `__init__`:

I don't really recommend it as a design pattern (cheap constructors that just do assignment/validation are generally preferred, for this among other reasons), but it works if you need it for backwards compatibility. If doing it from scratch, I would consider having a `normal()` function that constructs a `Normal()` object.
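Two minimal sketches of what the answer describes, assuming a recent JAX. The first keeps the broadcasting `__init__` but has the unflatten function allocate the object with `object.__new__`, so no array code runs on placeholder leaves. The second follows the recommended pattern of a cheap, assignment-only constructor plus a separate factory function; `Cheap` and `make_cheap` are hypothetical names standing in for `Normal` and `normal()`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Option 1: keep the broadcasting __init__, but bypass it when unflattening.
class Special:
    def __init__(self, x, y):
        shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
        self.x = jnp.broadcast_to(x, shape)
        self.y = jnp.broadcast_to(y, shape)

def special_unflatten(aux_data, children):
    obj = object.__new__(Special)  # allocate without calling __init__
    obj.x, obj.y = children        # leaves may be arrays, tracers or placeholders
    return obj

tree_util.register_pytree_node(
    Special, lambda v: ((v.x, v.y), None), special_unflatten)

out1 = jax.vmap(lambda x: Special(x, x))(jnp.ones(3))

# Option 2 (hypothetical names): assignment-only constructor plus a factory
# that does the broadcasting once, on user-supplied inputs.
class Cheap:
    def __init__(self, x, y):
        self.x = x
        self.y = y

def make_cheap(x, y):
    shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
    return Cheap(jnp.broadcast_to(x, shape), jnp.broadcast_to(y, shape))

tree_util.register_pytree_node(
    Cheap, lambda v: ((v.x, v.y), None), lambda aux, children: Cheap(*children))

# y is unbatched inside vmap; vmap broadcasts it along the mapped axis on output
out2 = jax.vmap(lambda x: make_cheap(x, jnp.zeros(2)))(jnp.ones(3))

print(out1.x.shape, out2.x.shape)
```

With option 2, unflattening is a plain reassembly of leaves, so it is safe no matter what values a transformation passes in, while shape checks live in the factory.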