Pytree Protocol
Hey! I’d like to propose that, in addition to the current registration mechanism, JAX allow users to define pytrees via a Pytree Protocol. There are two reasons for doing this:
- It simplifies the definition of custom Pytrees. Currently, the best way to ensure that a class and all its subclasses are registered is to override the __init_subclass__ method and perform the registration there. Implementing a protocol is more straightforward, since __init_subclass__ is not widely known.
- As the concept of a Pytree becomes more widespread within the Python ecosystem (e.g. pytorch/pytorch#65761 and dm-tree), a protocol could be a simple way of getting cross-library compatibility if the different implementations adopt it.
Implementation
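The idea would be to take the same mechanism implemented in register_pytree_node_class but with “special methods”:

```python
from typing import Any, Protocol, Sequence, Tuple

class Pytree(Protocol):
    def __tree_flatten__(self) -> Tuple[Sequence[Any], Any]:
        # Returns (children, auxiliary data).
        ...

    @classmethod
    def __tree_unflatten__(cls, children: Sequence[Any], aux: Any) -> Any:
        # Rebuilds an instance from the children and the auxiliary data.
        ...
```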
Any object that defines these methods would be treated as a Pytree at runtime.
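To make “treated as a Pytree at runtime” concrete, here is a minimal pure-Python sketch of a flattener that dispatches on the two proposed special methods instead of consulting a registry. The helpers tree_flatten, tree_unflatten and tree_map are illustrative only, not JAX’s implementations (which also treat lists, tuples and dicts as nodes; that is omitted here for brevity), and Point is a made-up example class.

```python
from typing import Any, List, Sequence, Tuple


class Point:
    """A user-defined container implementing the proposed special methods."""

    def __init__(self, x: Any, y: Any, label: str) -> None:
        self.x, self.y, self.label = x, y, label

    def __tree_flatten__(self) -> Tuple[Sequence[Any], Any]:
        return (self.x, self.y), self.label  # (children, aux)

    @classmethod
    def __tree_unflatten__(cls, children: Sequence[Any], aux: Any) -> "Point":
        x, y = children
        return cls(x, y, aux)


def is_node(obj: Any) -> bool:
    # Like other special methods, look them up on the type, not the instance.
    t = type(obj)
    return hasattr(t, "__tree_flatten__") and hasattr(t, "__tree_unflatten__")


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    if not is_node(tree):
        return [tree], None  # None in the treedef marks a leaf
    children, aux = type(tree).__tree_flatten__(tree)
    leaves: List[Any] = []
    child_defs = []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append(child_def)
    return leaves, (type(tree), aux, tuple(child_defs))


def tree_unflatten(treedef: Any, leaves: Sequence[Any]) -> Any:
    it = iter(leaves)

    def build(d: Any) -> Any:
        if d is None:
            return next(it)
        cls, aux, child_defs = d
        return cls.__tree_unflatten__([build(cd) for cd in child_defs], aux)

    return build(treedef)


def tree_map(f, tree: Any) -> Any:
    leaves, treedef = tree_flatten(tree)
    return tree_unflatten(treedef, [f(leaf) for leaf in leaves])
```

For example, tree_map(lambda v: v * 10, Point(1, Point(2, 3, "inner"), "outer")) rebuilds both Points with scaled coordinates while keeping their labels, without Point ever being registered anywhere.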
Other comments/ideas
- Pytrees seem like a general concept orthogonal to numerical computing; making JAX’s Pytree implementation a separate project that JAX depends on would have a positive impact on the whole Python ecosystem.
- Libraries that focus on general Pytree manipulation, like Treeo, could then become useful independently of JAX.
Issue Analytics
- Created 2 years ago
- Comments: 8 (5 by maintainers)
That’s a good idea, especially if libraries other than JAX are going to start using pytrees.
One thing I will mention is that I don’t think it makes sense to have a single global definition of what a pytree is. JAX pytrees have semantics that make sense for JAX APIs, but one can imagine wanting to treat a class as a container for JAX yet as an opaque object for some other API, or vice versa. So I suspect __tree_flatten__ might need to receive an object describing the flattener (yet to exist), so it can decide whether it wants to be flattened by that API.

I suspect this will become clearer as part of discussions about perhaps splitting pytree into its own library.
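One way that idea could look, sketched in pure Python under loose assumptions: Flattener, the “return None to opt out” convention, DeviceBuffer and tree_leaves are all hypothetical names invented for illustration, not an existing or agreed API. The object receives a description of the flattener and returns None to ask to be treated as a single opaque leaf.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple


@dataclass(frozen=True)
class Flattener:
    """Hypothetical description of the API that is doing the flattening."""
    name: str


class DeviceBuffer:
    """Example class: a container for "jax", but opaque to every other API."""

    def __init__(self, values: Sequence[Any], device: str) -> None:
        self.values = list(values)
        self.device = device

    def __tree_flatten__(self, flattener: Flattener) -> Optional[Tuple[Sequence[Any], Any]]:
        if flattener.name != "jax":
            return None  # opt out: be treated as a single leaf by this API
        return tuple(self.values), self.device  # (children, aux)


def tree_leaves(tree: Any, flattener: Flattener) -> List[Any]:
    flatten = getattr(type(tree), "__tree_flatten__", None)
    result = flatten(tree, flattener) if flatten is not None else None
    if result is None:
        return [tree]  # plain leaf, or a node that declined this flattener
    children, _aux = result
    leaves: List[Any] = []
    for child in children:
        leaves.extend(tree_leaves(child, flattener))
    return leaves
```

The same DeviceBuffer yields its three values to a "jax" flattener but comes back whole from any other, which is the per-API distinction described above.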
Another fix: please flip the order of aux and children in __tree_unflatten__. It has always bothered me that JAX’s flatten and unflatten have different orders, but I haven’t wanted the disruption of changing it.

@hawkinsp mentions sorting out the order of arguments to tree_unflatten. If we’re willing to make a compatibility break like that, then there are a couple of other things that I think might be desirable to sort out at the same time:

“PyDags” (cf. #7919)
At present, unflatten(flatten(...)) is not the identity function: it creates separate copies of elements that appear multiple times in the tree. There are times when this is undesirable; a notable one is when using PyTrees to represent parameterised functions (e.g. a PyTorch-like Module system), à la Equinox, Treex, etc.
In principle, I don’t think this should break any existing PyTree uses: these assume referential transparency, so they should be agnostic to whether they get a copy of each leaf/subtree or not.
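The difference can be illustrated with a small pure-Python sketch over nested tuples. Here flatten/unflatten mimic tree semantics (every occurrence of a leaf gets its own slot), while flatten_dag/unflatten_dag are a hypothetical DAG-aware variant that deduplicates leaves by object identity. Neither is JAX’s implementation, and a real design would also need to handle shared subtrees, not only shared leaves.

```python
from typing import Any, Dict, List, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Tree semantics: nested tuples are nodes; each occurrence becomes its own leaf."""
    if isinstance(tree, tuple):
        leaves: List[Any] = []
        defs = []
        for child in tree:
            child_leaves, child_def = flatten(child)
            leaves.extend(child_leaves)
            defs.append(child_def)
        return leaves, tuple(defs)
    return [tree], None  # None marks a leaf slot


def unflatten(treedef: Any, leaves: List[Any]) -> Any:
    it = iter(leaves)

    def build(d: Any) -> Any:
        return next(it) if d is None else tuple(build(c) for c in d)

    return build(treedef)


def flatten_dag(tree: Any) -> Tuple[List[Any], Any]:
    """DAG semantics: a leaf object occurring several times gets one shared slot."""
    leaves: List[Any] = []
    slot_of: Dict[int, int] = {}  # id(leaf) -> index into leaves

    def go(t: Any) -> Any:
        if isinstance(t, tuple):
            return tuple(go(c) for c in t)
        if id(t) not in slot_of:
            slot_of[id(t)] = len(leaves)
            leaves.append(t)
        return slot_of[id(t)]  # an int marks a (possibly shared) leaf slot

    return leaves, go(tree)


def unflatten_dag(treedef: Any, leaves: List[Any]) -> Any:
    if isinstance(treedef, tuple):
        return tuple(unflatten_dag(c, leaves) for c in treedef)
    return leaves[treedef]
```

With w = [0.5, -0.5] shared as in (w, (w, "bias")), copying each leaf through flatten/unflatten produces two distinct lists, whereas the DAG variant produces one list referenced twice. Deduplicating by identity may also merge equal immutable objects that CPython happens to cache (small ints, interned strings), which is harmless under the referential-transparency assumption above.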
No auxiliary data
This is something I’ve been thinking about for a while now. My current belief is that the auxiliary data output is actually unnecessary.

I think this muddles up two separate notions: (a) partitioning a PyTree into two (or even n) pieces; (b) (un)flattening a PyTree. A simpler, neater API is to have everything be a leaf, and then invoke something like tree1, tree2 = partition(tree, partition_rule) (or even tree1, tree2, tree3 = partition(...), etc.), and then pass tree1 and tree2 separately across an API boundary. Each of tree1 and tree2 is a PyTree with the same structure as tree; each contains some subset of the leaves and has some dummy value for the leaves it doesn’t have.

The above discussion on custom flatteners corresponds to a choice of partition rule.