Optional base class or class decorator for PyTree objects
Instead of requiring users to explicitly call `register_pytree_node()`, we could supply a base class that lets them instead define a pair of special methods for flattening/unflattening.
This looks slightly cleaner than calling functions on classes, but still avoids the evils of implementation inheritance.
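For comparison, here is a minimal sketch of the explicit registration that the proposal would replace. The `Point` class and the `_point_flatten`/`_point_unflatten` helpers are hypothetical names used only for illustration. `tree_util.register_pytree_node(type, flatten, unflatten)` is the existing JAX call, where `flatten` returns `(children, aux_data)`.

```python
from jax import tree_util


class Point:
    """Hypothetical plain class with no special base class or decorator."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "Point(x={}, y={})".format(self.x, self.y)


def _point_flatten(p):
    # Return (children, aux_data): the values to traverse plus static metadata.
    return (p.x, p.y), None


def _point_unflatten(aux_data, children):
    # Rebuild a Point from the children produced by _point_flatten.
    return Point(*children)


# The explicit call users must remember to make for every such class.
tree_util.register_pytree_node(Point, _point_flatten, _point_unflatten)

leaves, treedef = tree_util.tree_flatten(Point(1, 2))
print(leaves)  # [1, 2]
print(tree_util.tree_unflatten(treedef, leaves))  # Point(x=1, y=2)
```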
Here’s a working example, adapted from the pytree docs:
```python
from jax import tree_util

# base class that should go in tree_util.py
def _generic_flatten(tree):
    aux, children = tree.tree_flatten()
    # Keep the concrete class in the aux data so unflattening can rebuild it.
    return children, (type(tree), aux)

def _generic_unflatten(type_and_aux_data, children):
    cls, aux_data = type_and_aux_data
    return cls.tree_unflatten(aux_data, children)

# https://stackoverflow.com/questions/18126552/how-to-run-code-when-a-class-is-subclassed
class _AutoRegister(type):
    def __init__(cls, *args, **kwargs):
        # Runs for every class created with this metaclass, including PyTree itself.
        tree_util.register_pytree_node(cls, _generic_flatten, _generic_unflatten)
        super().__init__(*args, **kwargs)

class PyTree(metaclass=_AutoRegister):
    def tree_flatten(self):
        """Returns aux_data and children."""
        raise NotImplementedError

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Returns a new PyTree."""
        raise NotImplementedError

# example user code
class Special(PyTree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "Special(x={}, y={})".format(self.x, self.y)

    def tree_flatten(self):
        return (None, (self.x, self.y))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

# example usage
def show_example(structured):
    flat, tree = tree_util.tree_flatten(structured)
    unflattened = tree_util.tree_unflatten(tree, flat)
    print("structured={}\n flat={}\n tree={}\n unflattened={}".format(
        structured, flat, tree, unflattened))

show_example(Special(1, 2))
# outputs (the PyTreeDef repr differs in newer JAX versions):
# structured=Special(x=1, y=2)
# flat=[1, 2]
# tree=PyTreeDef(<class '__main__.Special'>[(<class '__main__.Special'>, None)], [*,*])
# unflattened=Special(x=1, y=2)
```
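The linked StackOverflow question is about running code when a class is subclassed. On Python 3.6+, `__init_subclass__` does that without a custom metaclass. Below is a minimal sketch of the same base class written that way; it is a swapped-in technique, not what the issue proposed. It keeps the issue's `(aux_data, children)` order for `tree_flatten`, and, unlike the metaclass version, it does not register the base class itself.

```python
from jax import tree_util


def _generic_flatten(tree):
    # tree_flatten() here follows the issue's convention: (aux_data, children).
    aux_data, children = tree.tree_flatten()
    # Store the concrete class in the aux data so unflatten knows what to build.
    return children, (type(tree), aux_data)


def _generic_unflatten(type_and_aux_data, children):
    cls, aux_data = type_and_aux_data
    return cls.tree_unflatten(aux_data, children)


class PyTree:
    """Hypothetical base class: every subclass is registered automatically."""

    def __init_subclass__(cls, **kwargs):
        # Called once per subclass definition; PyTree itself is not registered.
        super().__init_subclass__(**kwargs)
        tree_util.register_pytree_node(cls, _generic_flatten, _generic_unflatten)

    def tree_flatten(self):
        raise NotImplementedError

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        raise NotImplementedError


class Special(PyTree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        return None, (self.x, self.y)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# tree_map flattens, applies the function to each leaf, and unflattens.
doubled = tree_util.tree_map(lambda v: v * 2, Special(1, 2))
print(doubled.x, doubled.y)  # 2 4
```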
I like `register_pytree_node_class`. If it returns `cls`, we could use it as a decorator, too, e.g.:

It also avoids the API expansion of adding a base class that can be type checked.
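In current JAX, `jax.tree_util.register_pytree_node_class` does return `cls`, so it works as a class decorator. The sketch below reuses the `Special` class from the issue's example. Unlike the base-class sketch in the issue, this API expects `tree_flatten` to return `(children, aux_data)`, the same order that `register_pytree_node` uses.

```python
from jax import tree_util


@tree_util.register_pytree_node_class
class Special:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # Order is (children, aux_data), matching register_pytree_node.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


leaves, treedef = tree_util.tree_flatten(Special(1, 2))
print(leaves)  # [1, 2]
```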
I think we should keep `register_pytree_node_class` consistent with `register_pytree_node`, since the former is just a tiny wrapper around the latter. If someone changes the order of `register_pytree_node`, then both will be fixed automatically.

I don’t want to keep this issue open, so I’ll make the PR right now. I won’t update the docs, though.