Documentation for passing in trees
I would like to pass a dictionary through a vmap.
I tried to understand https://jax.readthedocs.io/en/latest/notebooks/JAX_pytrees.html but couldn’t figure out whether this is something I need to understand. These are the calls I tried:
vmap(loss, [None, 0, 0])(dictionary, X[i:i+batch], y[i:i+batch])
vmap(loss, [None, 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])
vmap(loss, [tree_flatten(dict_none), 0, 0])(tree_flatten(dictionary), X[i:i+batch], y[i:i+batch])
Either way I get an issue like this:
ValueError: Expected list, got (([<object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>], <object object at 0x7fcb21ef58b0>), <object object at 0x7fcb21ef58b0>, <object object at 0x7fcb21ef58b0>).
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/jax/api.py in _flatten_axes(treedef, axis_tree)
714 msg = ("axes specification must be a tree prefix of the corresponding "
715 "value, got specification {} for value {}.")
--> 716 raise ValueError(msg.format(axis_tree, treedef))
717 axes = [None if a is proxy else a for a in axes]
718 assert len(axes) == treedef.num_leaves
ValueError: axes specification must be a tree prefix of the corresponding value, got specification [([], PyTreeDef(dict[['dense1', 'dense2', 'dense3']], [PyTreeDef(dict[[]], []),PyTreeDef(dict[[]], []),PyTreeDef(dict[[]], [])])), 0, 0] for value PyTreeDef(tuple, [PyTreeDef(tuple, [PyTreeDef(list, [*,*,*,*,*,*]),*]),*,*]).
Issue Analytics
- State:
- Created 4 years ago
- Reactions:1
- Comments:7 (3 by maintainers)
Thanks for raising this, Sasha!
We don’t want to make you understand the details of pytrees; we want vmap and all the JAX transformations to “just work” on (nested) standard Python containers. In particular, you should not need to read the pytrees docs (and it does say in bold at the top, “This is primarily JAX internal documentation, end-users are not supposed to need to understand this to use JAX, except when registering new user-defined container types with JAX.”)
To more precisely answer your question, we might need a runnable repro. But this works:
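A minimal sketch of a working call of this kind. The `params` dictionary, the `loss` function, and the array shapes are hypothetical stand-ins, not the ones from the question; the point is only that the axis specification is a tuple and the dictionary is passed as-is, without flattening.

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical stand-ins for the dictionary and loss in the question.
params = {
    "dense1": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)},
}

def loss(params, x, y):
    # Inside vmap, x has shape (3,) and y has shape (2,): one example at a time.
    layer = params["dense1"]
    pred = jnp.dot(x, layer["w"]) + layer["b"]
    return jnp.sum((pred - y) ** 2)

X = jnp.ones((5, 3))   # batch of 5 inputs
y = jnp.zeros((5, 2))  # batch of 5 targets

# in_axes is a tuple: None broadcasts the whole dictionary to every example,
# and 0 maps over the leading axis of X and of y.
per_example_loss = vmap(loss, in_axes=(None, 0, 0))(params, X, y)
print(per_example_loss.shape)
```

The result has one loss value per example in the batch, and the dictionary never needs to go through `tree_flatten` by hand.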
Yet, understandably surprisingly, the same call with the axes given as a list, as in vmap(loss, [None, 0, 0]), doesn’t!
The issue is that the axis specification has to be a tree prefix of the args tuple, meaning an int (i.e. a kind of pytree, a leaf), a tuple (because args is a tuple), or a tuple of pytrees (including a tuple of ints, or a tuple of other kinds of pytrees). It can’t be a list!
I think this behavior is surprising because we’re so used to treating lists and tuples interchangeably in Python APIs, like we treat 'foo' and "foo" string quoting interchangeably.
I want to fix this!
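To make the "tree prefix" idea concrete, here is a small sketch of two axis specifications that satisfy it. The functions `affine` and `add` and all shapes are made up for illustration; only `vmap` and its `in_axes` argument come from JAX.

```python
import jax.numpy as jnp
from jax import vmap

def affine(params, x):
    # Hypothetical per-example function: a dot product plus a shared bias.
    return jnp.dot(x, params["w"]) + params["b"]

# "w" carries a leading batch axis of size 4; "b" is shared by all examples.
params = {"w": jnp.ones((4, 3)), "b": jnp.array(1.0)}
x = jnp.ones((4, 3))

# A tuple of pytrees: the first entry mirrors the dict argument leaf by leaf,
# the second entry is a single int that applies to x.
out = vmap(affine, in_axes=({"w": 0, "b": None}, 0))(params, x)

def add(a, b):
    return a + b

# A single int is a leaf, so it is a prefix of any args tuple:
# every argument is mapped over axis 0.
summed = vmap(add, in_axes=0)(x, x)

print(out.shape, summed.shape)
```

In both calls the specification is either a leaf or a tuple whose entries line up with the positional arguments, which is what the prefix rule described above requires.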
I will ask in a separate thread, but the docs seem to be missing entries on how to create pytrees from existing ones, for example to manipulate params or to look at sensitivities w.r.t. params.
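One possible way to do this, sketched under assumptions: `jax.grad` returns gradients with the same nested structure as its first argument, and `jax.tree_util.tree_map` builds a new pytree from existing ones. The `params` dictionary, the `loss` function, and the step size `0.1` below are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and loss, for illustration only.
params = {"dense1": {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}}

def loss(params, x, y):
    layer = params["dense1"]
    pred = jnp.dot(x, layer["w"]) + layer["b"]
    return jnp.sum((pred - y) ** 2)

x = jnp.ones(3)
y = jnp.zeros(2)

# Sensitivities w.r.t. params: a dict with the same keys and shapes as params.
grads = jax.grad(loss)(params, x, y)

# Build a new pytree from two existing ones, leaf by leaf.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

print(grads["dense1"]["b"].shape, new_params["dense1"]["w"].shape)
```

Because `grads` and `new_params` keep the dictionary layout of `params`, they can be indexed by the same keys or passed back into `loss` unchanged.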