Optional base class or class decorator for PyTree objects


Instead of requiring users to explicitly call register_pytree_node(), we could supply a base class that lets them define a pair of special methods for flattening and unflattening.

This looks slightly cleaner than calling functions on classes, but still avoids the evils of implementation inheritance.

Here’s a working example, adapted from the pytree docs:

import jax
from jax import tree_util

# baseclass that should go in tree_util.py

def _generic_flatten(tree):
  aux, children = tree.tree_flatten()
  return children, (type(tree), aux)

def _generic_unflatten(type_and_aux_data, children):
  cls, aux_data = type_and_aux_data
  return cls.tree_unflatten(aux_data, children)

# https://stackoverflow.com/questions/18126552/how-to-run-code-when-a-class-is-subclassed
class _AutoRegister(type):
  def __init__(cls, *args, **kwargs):
    tree_util.register_pytree_node(cls, _generic_flatten, _generic_unflatten)
    super().__init__(*args, **kwargs)

class PyTree(metaclass=_AutoRegister):
  def tree_flatten(self):
    """Returns aux_data and children."""
    raise NotImplementedError

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Returns a new PyTree."""
    raise NotImplementedError

# example user code

class Special(PyTree):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    return (None, (self.x, self.y))

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

# example usage

def show_example(structured):
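  # tree_util.tree_flatten/tree_unflatten also work on JAX versions
  # that no longer expose the top-level jax.tree_flatten aliases.
  flat, tree = tree_util.tree_flatten(structured)
  unflattened = tree_util.tree_unflatten(tree, flat)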
  print("structured={}\n  flat={}\n  tree={}\n  unflattened={}".format(
      structured, flat, tree, unflattened))

show_example(Special(1, 2))
# outputs:
# structured=Special(x=1, y=2)
#   flat=[1, 2]
#   tree=PyTreeDef(<class '__main__.Special'>[(<class '__main__.Special'>, None)], [*,*])
#   unflattened=Special(x=1, y=2)
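
For contrast, here is a minimal sketch of the explicit route that the proposed base class would hide. `Point`, `_point_flatten`, and `_point_unflatten` are hypothetical names used only for illustration. Note that `register_pytree_node` expects the flatten function to return `(children, aux_data)`, which is the reverse of what `tree_flatten` returns in the example above.

```python
from jax import tree_util


class Point:
  """Hypothetical container used only for this illustration."""

  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Point(x={}, y={})".format(self.x, self.y)


def _point_flatten(p):
  # register_pytree_node expects (children, aux_data), in that order.
  return (p.x, p.y), None


def _point_unflatten(aux_data, children):
  return Point(*children)


# Explicit registration: the step the proposed base class would automate.
tree_util.register_pytree_node(Point, _point_flatten, _point_unflatten)

# Once registered, the class participates in the usual tree utilities.
doubled = tree_util.tree_map(lambda v: v * 2, Point(1, 2))
print(doubled)
```

Once registered, by whichever route, the class is treated as an internal node by `tree_map` and the other `tree_util` functions.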

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

shoyer commented, Mar 10, 2020

I like register_pytree_node_class. If it returns cls, we could use it as a decorator, too, e.g.,

@register_pytree_node_class
class Special:
  ...

It also avoids the API expansion of adding a base class that can be type checked.
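
A fuller sketch of the decorator form, assuming a JAX version that ships `jax.tree_util.register_pytree_node_class` and that `tree_flatten` returns `(children, aux_data)` (the reverse of the base-class example in the issue body):

```python
from jax import tree_util


@tree_util.register_pytree_node_class
class Special:
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

  def tree_flatten(self):
    # Children first, then aux_data, matching register_pytree_node.
    return (self.x, self.y), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


leaves, treedef = tree_util.tree_flatten(Special(1, 2))
restored = tree_util.tree_unflatten(treedef, leaves)
print(leaves, restored)
```

Because the decorator returns the class unchanged, `Special` needs no base class and no metaclass.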

mattjj commented, Mar 10, 2020

I think we should keep register_pytree_node_class consistent with register_pytree_node, since the former is just a tiny wrapper around the latter. If someone changes the order of register_pytree_node then both will be fixed automatically.

I don’t want to keep this issue open, so I’ll make the PR right now. I won’t update the docs, though.
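
The "tiny wrapper" relationship described in this comment can be sketched roughly as follows. This is an illustration, not the code from the actual PR: `register_class_sketch` and `Pair` are hypothetical names, and the only library call assumed is `tree_util.register_pytree_node`.

```python
from jax import tree_util


def register_class_sketch(cls):
  """Hypothetical stand-in for register_pytree_node_class."""
  # Delegate to register_pytree_node so the (children, aux_data)
  # ordering is defined in exactly one place.
  tree_util.register_pytree_node(
      cls, lambda obj: obj.tree_flatten(), cls.tree_unflatten)
  return cls  # returning cls lets the function double as a decorator


@register_class_sketch
class Pair:
  def __init__(self, a, b):
    self.a = a
    self.b = b

  def tree_flatten(self):
    return (self.a, self.b), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)


leaves, treedef = tree_util.tree_flatten(Pair(3, 4))
restored = tree_util.tree_unflatten(treedef, leaves)
print(leaves)
```

Since the wrapper passes the class's methods straight through, any convention that `register_pytree_node` imposes on its flatten and unflatten functions applies to `tree_flatten` and `tree_unflatten` as well.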
