vmap is incompatible with custom pytree

The following repro code

from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special(object):
    def __init__(self, x, y):
        shape = lax.broadcast_shapes(np.shape(x), np.shape(y))
        self.x = np.broadcast_to(x, shape)
        self.y = np.broadcast_to(y, shape)

def special_flatten(v):
    return ((v.x, v.y), None)

def special_unflatten(aux_data, children):
    return Special(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return Special(x, x)

assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert pmap(f)(np.ones(3)).x.shape == (3,)
vmap(f)(np.ones(3))  # fail!

triggers the error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-36-e7c65535399a> in <module>
     18 assert jit(f)(np.ones(3)).x.shape == (2, 2, 3)
     19 assert lax.map(f, np.ones(3)).x.shape == (2, 3, 2)
---> 20 vmap(f)(np.ones(3))

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in batched_fun(*args)
    768     _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
    769     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770                               lambda: flatten_axes(out_tree(), out_axes))
    771     return tree_unflatten(out_tree(), out_flat)
    772 

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35 
     36 @lu.transformation_with_aux

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/linear_util.py in call_wrapped(***failed resolving arguments***)
    152     while stack:
    153       gen, out_store = stack.pop()
--> 154       ans = gen.send(ans)
    155       if out_store is not None:
    156         ans, side = ans

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/interpreters/batching.py in _batch_fun(sum_match, in_dims, out_dims_thunk, out_dim_dests, *in_vals, **params)
     57     out_vals = yield (master, in_dims,) + in_vals, params
     58     del master
---> 59   out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
     60   out_dims = out_dims_thunk()
     61   for od, od_dest in zip(out_dims, out_dim_dests):

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api.py in <lambda>()
    768     _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
    769     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 770                               lambda: flatten_axes(out_tree(), out_axes))
    771     return tree_unflatten(out_tree(), out_flat)
    772 

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/api_util.py in flatten_axes(treedef, axis_tree)
    106   # TODO(mattjj,phawkins): improve this implementation
    107   proxy = object()
--> 108   dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
    109   axes = []
    110   add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/tree_util.py in tree_unflatten(treedef, leaves)
     68     structure described by `treedef`.
     69   """
---> 70   return treedef.unflatten(leaves)
     71 
     72 def tree_leaves(tree):

<ipython-input-36-e7c65535399a> in special_unflatten(aux_data, children)
      9 
     10 def special_unflatten(aux_data, children):
---> 11     return Special(*children)
     12 
     13 tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

<ipython-input-36-e7c65535399a> in __init__(self, x)
      3 class Special(object):
      4     def __init__(self, x):
----> 5         self.x = np.broadcast_to(x, (2,) + np.shape(x))
      6 
      7 def special_flatten(v):

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in broadcast_to(arr, shape)
   1280 def broadcast_to(arr, shape):
   1281   """Like Numpy's broadcast_to but doesn't necessarily return views."""
-> 1282   arr = arr if isinstance(arr, ndarray) else array(arr)
   1283   shape = canonicalize_shape(shape)  # check that shape is concrete
   1284   arr_shape = _shape(arr)

~/miniconda3/envs/pydata/lib/python3.6/site-packages/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   2070       return array(onp.asarray(view), dtype, copy)
   2071 
-> 2072     raise TypeError("Unexpected input type for array: {}".format(type(object)))
   2073 
   2074   if ndmin > ndim(out):

TypeError: Unexpected input type for array: <class 'object'>

The issue here is that np.broadcast_to is applied to a plain Python object x (instead of an ndarray) when unflattening the tree under vmap. Is there any workaround available in JAX?
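The traceback shows flatten_axes rebuilding the output tree with `[object()] * treedef.num_leaves`, so the unflatten function is handed placeholders rather than arrays. The sketch below reproduces only that step, outside vmap, assuming a recent JAX release; `Pair` and `seen` are hypothetical names used for illustration, and the constructor is deliberately trivial so that nothing fails.

```python
from jax import numpy as jnp, tree_util

class Pair:
    """Hypothetical container whose constructor only assigns its arguments."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

seen = []  # records the leaf types passed to the unflatten function

def pair_flatten(p):
    return ((p.x, p.y), None)

def pair_unflatten(aux_data, children):
    children = tuple(children)
    seen.append(tuple(type(c).__name__ for c in children))
    return Pair(*children)

tree_util.register_pytree_node(Pair, pair_flatten, pair_unflatten)

# Flattening a real instance yields the tree structure.
treedef = tree_util.tree_structure(Pair(jnp.ones(3), jnp.ones(3)))

# Same call as in the traceback: every leaf is a bare object(), not an array.
dummy = tree_util.tree_unflatten(treedef, [object()] * treedef.num_leaves)
print(seen[-1])
```

Any array operation in the constructor (such as np.broadcast_to in the repro) would be applied to these placeholders, which is where the TypeError comes from.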

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, Jun 2, 2020

I think this issue is resolved, but if not please let us know @fehiepsi ! By the way, it’s great to hear from you again 😄

1 reaction
shoyer commented, May 31, 2020

Like I said, you can use some Python magic for an alternative constructor that skips __init__:

from jax import tree_util, lax, vmap, numpy as np, jit, pmap

class Special:
    def __init__(self, x, y):
        shape = lax.broadcast_shapes(np.shape(x), np.shape(y))
        self.x = np.broadcast_to(x, shape)
        self.y = np.broadcast_to(y, shape)
    
    @classmethod
    def restore(cls, x, y):
        obj = object.__new__(cls)
        obj.x = x
        obj.y = y
        return obj

def special_flatten(v):
    return ((v.x, v.y), None)

def special_unflatten(aux_data, children):
    return Special.restore(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return Special(x, x)

assert jit(f)(np.ones(3)).x.shape == (3,)
assert lax.map(f, np.ones(3)).x.shape == (3,)
assert vmap(f)(np.ones(3)).x.shape == (3,)

I don’t really recommend it as a design pattern (cheap constructors that just do assignment/validation are generally preferred, for this among other reasons), but it works if you need it for backwards compatibility. If doing it from scratch, I would consider having a normal() function that constructs a Normal() object.
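A minimal sketch of that alternative, applied to the Special class from this thread and assuming a recent JAX release: the constructor only assigns, and a separate factory function does the broadcasting. The name `special` for the factory is hypothetical, chosen by analogy with the normal()/Normal() suggestion.

```python
from jax import jit, lax, numpy as jnp, tree_util, vmap

class Special:
    """Cheap constructor: plain assignment, safe to call with placeholder leaves."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

def special(x, y):
    """Hypothetical factory: broadcast x and y to a common shape, then wrap them."""
    shape = lax.broadcast_shapes(jnp.shape(x), jnp.shape(y))
    return Special(jnp.broadcast_to(x, shape), jnp.broadcast_to(y, shape))

def special_flatten(v):
    return ((v.x, v.y), None)

def special_unflatten(aux_data, children):
    # Calls the cheap constructor, so no array operation runs on the leaves.
    return Special(*children)

tree_util.register_pytree_node(Special, special_flatten, special_unflatten)

def f(x):
    return special(x, x)

out_jit = jit(f)(jnp.ones(3))
out_vmap = vmap(f)(jnp.ones(3))
```

With this layout user code calls special(...) while the pytree machinery calls Special(...) directly, so unflattening never depends on the leaves being arrays.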
