Documentation for passing in trees

I would like to pass a dictionary through a vmap.

I tried to understand https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html, but I couldn't figure out whether it is something I need to understand.

vmap(loss, [None, 0, 0])(dictionary, X[i:i+batch], y[i:i+batch])
vmap(loss, [None, 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])
vmap(loss, [tree_flatten(dict_none), 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])

Each attempt fails with an error like this:

ValueError: Expected list, got (([<object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>], <object object at 0x7fcb21ef58b0>), <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>).

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/api.py in _flatten_axes(treedef, axis_tree)
    714     msg = ("axes specification must be a tree prefix of the corresponding "
    715            "value, got specification {} for value {}.")
--> 716     raise ValueError(msg.format(axis_tree, treedef))
    717   axes = [None if a is proxy else a for a in axes]
    718   assert len(axes) == treedef.num_leaves

ValueError: axes specification must be a tree prefix of the corresponding value, got specification [([], PyTreeDef(dict[['dense1', 'dense2', 'dense3']], [PyTreeDef(dict[[]], []),PyTreeDef(dict[[]], []),PyTreeDef(dict[[]], [])])), 0, 0] for value PyTreeDef(tuple, [PyTreeDef(tuple, [PyTreeDef(list, [*,*,*,*,*,*]),*]),*,*]).
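The second and third attempts pass `tree_flatten(dictionary)` instead of the dictionary itself. `jax.tree_util.tree_flatten` returns a `(leaves, treedef)` pair, which is why the value in the error message contains an extra `PyTreeDef(tuple, [PyTreeDef(list, [...]), *])` level: a list of six leaves plus the treedef. A small sketch, using a hypothetical two-layer `params` dict (the layer and key names are assumptions):

```python
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

# Hypothetical nested parameters: two layers, each with 'W' and 'b'.
params = {'dense1': {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)},
          'dense2': {'W': jnp.ones((2, 1)), 'b': jnp.zeros(1)}}

# tree_flatten returns a pair, not a dict: a list of array leaves plus a
# PyTreeDef describing the nesting.
flat = tree_flatten(params)
leaves, treedef = flat

# tree_unflatten reverses the process and rebuilds the original nested dict.
rebuilt = tree_unflatten(treedef, leaves)
```

Since vmap accepts the nested dict directly, there is normally no need to flatten it by hand.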

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 1
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, Mar 10, 2020

Thanks for raising this, Sasha!

We don’t want to make you understand the details of pytrees; we want vmap and all the JAX transformations to “just work” on (nested) standard Python containers. In particular, you should not need to read the pytrees docs (and it does say in bold at the top, “This is primarily JAX internal documentation, end-users are not supposed to need to understand this to use JAX, except when registering new user-defined container types with JAX.”)

To more precisely answer your question, we might need a runnable repro. But this works:

import jax.numpy as np
from jax import vmap

dictionary = {'a': 5., 'b': np.ones(2)}
x = np.zeros(3)
y = np.arange(3.)


def f(dct, x, y):
  return dct['a'] + dct['b'] + x + y

result = vmap(f, (None, 0, 0))(dictionary, x, y)

Yet, as you would understandably find surprising, this doesn't work!

result = vmap(f, [None, 0, 0])(dictionary, x, y)

# ValueError: axes specification must be a tree prefix
# of the corresponding value, got specification [None, 0, 0]
# for value PyTreeDef(tuple, [PyTreeDef(dict[['a', 'b']], [*,*]),*,*]).

The issue is that the axis specification has to be a tree prefix of the args tuple. That means it can be a single int (a leaf, which is the simplest kind of pytree), or a tuple (because args is a tuple) whose entries are themselves pytrees, such as ints, None, or other containers. It can't be a list!
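The tree-prefix rule is easier to see with a few `in_axes` values side by side. The sketch below redefines `f`, `x`, and `y` from the working example; the dict-valued spec in case 2 and the `tuple(...)` conversion in case 3 are illustrations added here, not code from the thread.

```python
import jax.numpy as jnp
from jax import vmap

def f(dct, x, y):
    return dct['a'] + dct['b'] + x + y

x = jnp.zeros(3)
y = jnp.arange(3.)

# 1. A tuple with one entry per positional argument (the form that works above).
d1 = {'a': 5., 'b': jnp.ones(2)}
r1 = vmap(f, (None, 0, 0))(d1, x, y)                  # shape (3, 2)

# 2. Descend into the dict: broadcast 'a' but map 'b' along axis 0.
#    The inner spec must have the same keys as the dict argument.
d2 = {'a': 5., 'b': jnp.ones((3, 2))}
r2 = vmap(f, ({'a': None, 'b': 0}, 0, 0))(d2, x, y)   # shape (3, 2)

# 3. If the axes arrive as a list, converting to a tuple gives a valid prefix.
axes = [None, 0, 0]
r3 = vmap(f, tuple(axes))(d1, x, y)                   # same as r1
```

Case 2 works because `{'a': None, 'b': 0}` mirrors the structure of the dict argument, so vmap can pair each leaf with its own axis.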

I think this behavior is surprising because we’re so used to treating lists and tuples interchangeably in Python APIs, like we treat 'foo' and "foo" string quoting interchangeably.

I want to fix this!
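Applied to the setup in the question, the fix is to pass the parameter dict itself and give `in_axes` as a tuple. A minimal sketch: the layer names `dense1` to `dense3` come from the error message, but the `'W'`/`'b'` keys, all shapes, and the body of `loss` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax import vmap

# Hypothetical 3-layer MLP parameters (keys 'W'/'b' and shapes are assumed).
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    'dense1': {'W': jax.random.normal(k1, (4, 8)), 'b': jnp.zeros(8)},
    'dense2': {'W': jax.random.normal(k2, (8, 8)), 'b': jnp.zeros(8)},
    'dense3': {'W': jax.random.normal(k3, (8, 1)), 'b': jnp.zeros(1)},
}

def loss(params, x, y):
    """Squared error for one example x of shape (4,) and a scalar target y."""
    h = jnp.tanh(x @ params['dense1']['W'] + params['dense1']['b'])
    h = jnp.tanh(h @ params['dense2']['W'] + params['dense2']['b'])
    pred = (h @ params['dense3']['W'] + params['dense3']['b'])[0]
    return (pred - y) ** 2

X = jnp.ones((32, 4))
y = jnp.zeros(32)
i, batch = 0, 16

# None broadcasts the whole params dict; 0 maps over the rows of X and y.
per_example = vmap(loss, (None, 0, 0))(params, X[i:i+batch], y[i:i+batch])
```

`per_example` holds one loss value per row of the batch.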

0 reactions
cottrell commented, Sep 3, 2021

I will ask in a separate thread, but the docs seem to be missing entries on how to create new pytrees from existing ones (for example, to manipulate params).

One case where this comes up is computing sensitivities with respect to params.
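One common way to build a new pytree from an existing one is `jax.tree_util.tree_map`, and `jax.grad` of a function that takes a params dict returns gradients in the same dict structure. A minimal sketch, where `params` and `loss` are hypothetical stand-ins rather than anything from the thread:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and loss, for illustration only.
params = {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}

def loss(params, x):
    return jnp.sum(params['w'] * x) + params['b']

x = jnp.array([3.0, 4.0])

# A new pytree with the same structure, built leaf-by-leaf from an existing one.
scaled = jax.tree_util.tree_map(lambda p: 2.0 * p, params)

# Sensitivities w.r.t. params come back as a dict with the same keys:
# d loss / d w == x and d loss / d b == 1.
sens = jax.grad(loss)(params, x)

# tree_map also combines two pytrees of the same structure, e.g. one SGD step.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, sens)
```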
