
Hey! I’d like to propose that, in addition to the current registration mechanism, JAX allow users to define pytrees via a Pytree protocol. There are two reasons for doing this:

  1. It simplifies the definition of custom pytrees. Currently, the best way to ensure that a class and all its subclasses are registered is to override the __init_subclass__ method and perform the registration there. Implementing a protocol is more straightforward, since __init_subclass__ is not widely known.
  2. As the concept of a pytree becomes more widespread within the Python ecosystem (e.g. pytorch/pytorch#65761 and dm-tree), a protocol could be a simple way of getting cross-library compatibility if the different implementations adopt it.

Implementation

The idea would be to reuse the mechanism implemented in register_pytree_node_class, but with “special methods”:

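from typing import Any, Protocol, Sequence, Tuple

class Pytree(Protocol):
    # Returns (children, aux_data).
    def __tree_flatten__(self) -> Tuple[Sequence[Any], Any]:
        ...

    # Rebuilds an instance from children and aux_data (same order as __tree_flatten__).
    @classmethod
    def __tree_unflatten__(cls, children: Sequence[Any], aux: Any) -> Any:
        ...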

Any object that defines these methods would be treated as a pytree at runtime.
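A minimal sketch of what runtime detection could look like, assuming a toy flattener (not JAX's) in which lists, tuples and dicts are built-in nodes and any object whose type defines __tree_flatten__ is a node with no registration step. The tree_flatten / tree_unflatten helpers and the Point class are hypothetical illustrations.

```python
from typing import Any, List, Tuple


def tree_flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Toy flattener: returns (leaves, treedef); a treedef of None marks a leaf."""
    node_type = type(tree)
    if hasattr(node_type, "__tree_flatten__"):
        # Duck-typed protocol check: no registry lookup needed.
        children, aux = tree.__tree_flatten__()
    elif node_type in (list, tuple):
        children, aux = list(tree), None
    elif node_type is dict:
        aux = sorted(tree)  # keys become static (aux) data
        children = [tree[k] for k in aux]
    else:
        return [tree], None
    leaves: List[Any] = []
    child_defs = []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        child_defs.append(child_def)
    return leaves, (node_type, aux, child_defs)


def tree_unflatten(treedef: Any, leaves: List[Any]) -> Any:
    it = iter(leaves)

    def build(d: Any) -> Any:
        if d is None:
            return next(it)
        node_type, aux, child_defs = d
        children = [build(c) for c in child_defs]
        if hasattr(node_type, "__tree_unflatten__"):
            return node_type.__tree_unflatten__(children, aux)
        if node_type is dict:
            return dict(zip(aux, children))
        return node_type(children)

    return build(treedef)


class Point:
    def __init__(self, x: Any, y: Any, name: str) -> None:
        self.x, self.y, self.name = x, y, name

    def __tree_flatten__(self) -> Tuple[List[Any], str]:
        return [self.x, self.y], self.name  # name is static aux data

    @classmethod
    def __tree_unflatten__(cls, children: List[Any], aux: str) -> "Point":
        x, y = children
        return cls(x, y, aux)


# Usage: dict keys are visited in sorted order, so "p" comes before "z".
leaves, treedef = tree_flatten({"z": [3, 4], "p": Point(1, 2, "a")})
new = tree_unflatten(treedef, [v * 10 for v in leaves])
```

Here leaves is [1, 2, 3, 4], and new holds a fresh Point(10, 20, "a") under "p" and [30, 40] under "z".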

Other comments/ideas

  • Pytrees seem like a general concept orthogonal to numerical computing. Making JAX’s pytree implementation a separate project that JAX depends on would have a positive impact on the whole Python ecosystem.
  • Libraries that focus on general pytree manipulation, like Treeo, could then become useful independently of JAX.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 8 (5 by maintainers)

Top GitHub Comments

5 reactions
hawkinsp commented, Oct 6, 2021

That’s a good idea, especially if libraries other than JAX are going to start using pytrees.

One thing I will mention is that I don’t think it makes sense to have a single global definition of what a pytree is. JAX pytrees have semantics that make sense for JAX APIs, but one can imagine wanting to treat a class as a container for JAX and as an opaque object for some other API, or vice versa. So I suspect __tree_flatten__ might need to receive an object describing the flattener (which doesn’t exist yet), so that it can decide whether it wants to be flattened by that API.
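The flattener-description idea could look roughly like this. Everything here is hypothetical: the Flattener class, its name attribute, and the convention that returning None means “treat me as an opaque leaf” are assumptions made for illustration, not an existing API.

```python
from typing import Any, List, Optional, Tuple


class Flattener:
    """Hypothetical description of the API that is doing the flattening."""

    def __init__(self, name: str) -> None:
        self.name = name


class Params:
    def __init__(self, weights: Any, bias: Any) -> None:
        self.weights, self.bias = weights, bias

    def __tree_flatten__(self, flattener: Flattener) -> Optional[Tuple[List[Any], Any]]:
        # Act as a container only for the (hypothetical) "jax" flattener;
        # returning None asks any other flattener to treat this object as a leaf.
        if flattener.name == "jax":
            return [self.weights, self.bias], None
        return None


def leaves(tree: Any, flattener: Flattener) -> List[Any]:
    """Collect leaves, letting each protocol object decide per flattener."""
    flatten = getattr(type(tree), "__tree_flatten__", None)
    result = flatten(tree, flattener) if flatten is not None else None
    if result is None:
        return [tree]  # no protocol, or the object opted out: it is a leaf
    children, _aux = result
    out: List[Any] = []
    for child in children:
        out.extend(leaves(child, flattener))
    return out
```

With p = Params(1.0, 0.5), leaves(p, Flattener("jax")) yields the two numbers, while leaves(p, Flattener("other")) yields [p] itself.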

I suspect this will become clearer as part of discussions about perhaps splitting pytree into its own library.

Another fix: please flip the order of aux and children in __tree_unflatten__. It has always bothered me that JAX’s flatten and unflatten have different orders, but I haven’t wanted the disruption of changing it.
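A small sketch of why matching orders help: when __tree_unflatten__ takes (children, aux) in the same order that __tree_flatten__ returns them, the flatten output can be passed straight back in. The Pair class is a hypothetical example; its tree_flatten / tree_unflatten pair follows the argument order JAX’s register_pytree_node_class expects today, for comparison.

```python
from typing import Any, List, Tuple


class Pair:
    def __init__(self, a: Any, b: Any, tag: str) -> None:
        self.a, self.b, self.tag = a, b, tag

    # Proposed protocol: flatten returns (children, aux) ...
    def __tree_flatten__(self) -> Tuple[List[Any], str]:
        return [self.a, self.b], self.tag

    # ... and unflatten accepts (children, aux) in the same order.
    @classmethod
    def __tree_unflatten__(cls, children: List[Any], aux: str) -> "Pair":
        return cls(*children, aux)

    # JAX's current convention: flatten returns (children, aux_data),
    # but unflatten takes (aux_data, children).
    def tree_flatten(self) -> Tuple[List[Any], str]:
        return [self.a, self.b], self.tag

    @classmethod
    def tree_unflatten(cls, aux: str, children: List[Any]) -> "Pair":
        return cls(*children, aux)


p = Pair(1, 2, "t")
# Matching orders: splat the flatten output directly back in.
q = Pair.__tree_unflatten__(*p.__tree_flatten__())
# Mismatched orders: the caller has to swap the two values first.
children, aux = p.tree_flatten()
r = Pair.tree_unflatten(aux, children)
```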

2 reactions
patrick-kidger commented, Oct 10, 2021

@hawkinsp mentions sorting out the order of arguments to tree_unflatten. If we’re willing to make a compatibility break like that, then there are a couple of other things that I think might be desirable to sort out at the same time:

“PyDags” (cf. #7919): at present, unflatten(flatten(...)) is not the identity function.

It will create separate copies of elements that appear multiple times in the tree. There are times when this is undesirable; a notable one is when using PyTrees to represent parameterised functions (e.g. a PyTorch-like Module system, à la Equinox, Treex, etc.).

In principle, I don’t think this should break any existing PyTree uses: these assume referential transparency, so they should be agnostic to whether or not they get a copy of each leaf/subtree.
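The copying behaviour can be reproduced with a toy flattener over nested lists (not JAX’s implementation). The second pair of hypothetical functions, flatten_dag / unflatten_dag, sketches one way a “PyDag” variant could preserve sharing: memoize nodes by id() during flattening and record repeats as back-references.

```python
from typing import Any, Dict, List, Optional, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Toy flattener: lists are nodes, everything else is a leaf (treedef None)."""
    if not isinstance(tree, list):
        return [tree], None
    leaves: List[Any] = []
    defs = []
    for child in tree:
        child_leaves, child_def = flatten(child)
        leaves.extend(child_leaves)
        defs.append(child_def)
    return leaves, defs


def unflatten(treedef: Any, leaves: List[Any]) -> Any:
    it = iter(leaves)

    def build(d: Any) -> Any:
        return next(it) if d is None else [build(c) for c in d]

    return build(treedef)


def flatten_dag(tree: Any, memo: Optional[Dict[int, int]] = None) -> Tuple[List[Any], Any]:
    """Like flatten, but a list seen before becomes a ("ref", index) entry."""
    if memo is None:
        memo = {}
    if not isinstance(tree, list):
        return [tree], None
    if id(tree) in memo:
        return [], ("ref", memo[id(tree)])
    memo[id(tree)] = len(memo)  # pre-order index of this node
    leaves: List[Any] = []
    defs = []
    for child in tree:
        child_leaves, child_def = flatten_dag(child, memo)
        leaves.extend(child_leaves)
        defs.append(child_def)
    return leaves, ("node", defs)


def unflatten_dag(treedef: Any, leaves: List[Any]) -> Any:
    it = iter(leaves)
    built: List[list] = []  # nodes in the same pre-order as flatten_dag

    def build(d: Any) -> Any:
        if d is None:
            return next(it)
        kind, payload = d
        if kind == "ref":
            return built[payload]
        node: list = []
        built.append(node)
        node.extend(build(c) for c in payload)
        return node

    return build(treedef)


shared = [1, 2]
tree = [shared, shared]
leaves, treedef = flatten(tree)
rebuilt = unflatten(treedef, leaves)  # two separate copies of `shared`
dag_leaves, dag_def = flatten_dag(tree)
dag_rebuilt = unflatten_dag(dag_def, dag_leaves)  # sharing preserved
```

In this sketch rebuilt == tree but rebuilt[0] is not rebuilt[1], whereas dag_rebuilt[0] is dag_rebuilt[1] and the shared leaves appear only once in dag_leaves.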

No auxiliary data: this is something I’ve been thinking about for a while now. My current belief is that the auxiliary data output is actually unnecessary.

I think it muddles up two separate notions: (a) partitioning a PyTree into two (or even n) pieces; (b) (un)flattening a PyTree. A simpler/neater API is to have everything be a leaf, and then invoke some tree1, tree2 = partition(tree, partition_rule) (or even some tree1, tree2, tree3 = partition(...), etc.) and then pass tree1 and tree2 separately across an API boundary. Each of tree1 and tree2 is a PyTree with the same structure as tree; each contains some subset of the leaves and has some dummy value for the leaves it doesn’t have.
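A minimal sketch of the partition idea over a toy pytree of nested dicts and lists. The helpers tree_map, partition and combine are hypothetical names written for this illustration, and None is assumed as the dummy value for missing leaves.

```python
from typing import Any, Callable, Tuple


def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    """Apply f to every leaf of a toy pytree of dicts and lists."""
    if isinstance(tree, dict):
        return {k: tree_map(f, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tree_map(f, v) for v in tree]
    return f(tree)


def partition(tree: Any, rule: Callable[[Any], bool]) -> Tuple[Any, Any]:
    """Split into two trees of identical structure; None fills the gaps."""
    kept = tree_map(lambda x: x if rule(x) else None, tree)
    rest = tree_map(lambda x: None if rule(x) else x, tree)
    return kept, rest


def combine(tree1: Any, tree2: Any) -> Any:
    """Inverse of partition: take each leaf from whichever tree has it."""
    if isinstance(tree1, dict):
        return {k: combine(tree1[k], tree2[k]) for k in tree1}
    if isinstance(tree1, list):
        return [combine(a, b) for a, b in zip(tree1, tree2)]
    return tree2 if tree1 is None else tree1


tree = {"w": [1.0, 2.0], "act": "relu"}
# Floats are the "dynamic" part; the string plays the role that aux data used to.
keep, rest = partition(tree, lambda x: isinstance(x, float))
```

keep is {"w": [1.0, 2.0], "act": None} and rest is {"w": [None, None], "act": "relu"}; each can cross an API boundary on its own, and combine(keep, rest) restores the original tree.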

The above discussion on custom flatteners corresponds to a choice of partition rule.
