Proposal: mechanism to preserve relative identities for custom Pytrees

Motivation

When jax.tree_unflatten is called, new instances of all pytree objects are created and their original identities are lost. This is expected for basic container types like lists, dicts, and tuples, but for other types, such as pytree Modules, it can be inconvenient because it makes tasks like parameter sharing difficult. Here is an example, which currently doesn’t work, of sharing one Child module across two fields of the same Parent module:

import jax
import jax.numpy as jnp

# Module is assumed to be a base class whose subclasses are registered as pytrees.
class Child(Module):
    x: jnp.ndarray  # assume this is a leaf
    ...

class Parent(Module):
    left: Child  # assume these are subtrees
    right: Child
    ...

child = Child(x=jnp.array(1))
parent = Parent(left=child, right=child)  # <<<< child is shared

@jax.jit
def f(parent):
    assert parent.left is parent.right  # currently fails: the shared identity is lost
    return parent

parent2 = f(parent)
assert parent2.left is parent2.right  # currently fails as well
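
For comparison, the behavior described above as expected for built-in containers can be checked directly. This is a small sketch that relies only on jax.tree_util.tree_flatten and jax.tree_util.tree_unflatten; the variable names are illustrative.

```python
import jax

shared = [1.0, 2.0]
tree = (shared, shared)  # the same list object appears in both slots

leaves, treedef = jax.tree_util.tree_flatten(tree)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

assert tree[0] is tree[1]             # shared before flattening
assert len(leaves) == 4               # the shared list is flattened once per slot
assert rebuilt[0] == rebuilt[1]       # equal contents after unflattening
assert rebuilt[0] is not rebuilt[1]   # but two separate list objects
```

The flattened form keeps no record that both slots held the same object, so unflattening has nothing from which to restore the sharing.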

Proposal

Enable preserving the relative identities of custom Pytree classes that opt in to this behavior.

Relative identities means that if two entries in a Pytree are the same object before flattening, the corresponding entries are again one and the same object after unflattening, but that object is not identical to the original. The following assertions therefore hold, assuming Module has opted in to this behavior:

m = Module()

@jax.jit
def f(m1, m2):
    assert m1 is m2 
    return m1, m2

m1, m2 = f(m, m)

assert m1 is m2
assert m is not m1 and m is not m2

Implementation

To achieve this, register_pytree_node could accept an optional preserve_relative_identities: bool flag (or something similar) indicating that objects of this class opt in to preserving their relative identities. When tree_flatten is called, each node’s object id could be stored in the PyTreeDef. Then, when tree_unflatten reaches a node whose class was registered with preserve_relative_identities=True, it would check by id whether it has already unflattened that object and, if so, reuse the previously built node.

preserve_relative_identities should also be available for register_pytree_node_class.
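
The idea can be sketched in plain Python, without touching JAX. Everything below is hypothetical: register_node, tree_flatten, and tree_unflatten are toy stand-ins rather than JAX functions, and two choices are assumptions of this sketch rather than part of the proposal: a repeated object is recorded as a back-reference instead of being flattened again, and objects are labeled 0, 1, 2, ... instead of by raw id so that the tree definition does not depend on memory addresses.

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical toy registry: type -> (flatten, unflatten, preserve flag).
_registry: Dict[type, Tuple[Callable, Callable, bool]] = {}

def register_node(cls, flatten, unflatten, preserve_relative_identities=False):
    _registry[cls] = (flatten, unflatten, preserve_relative_identities)

def tree_flatten(tree: Any):
    leaves = []
    seen: Dict[int, int] = {}  # id(node) -> label, for opted-in nodes only

    def go(node):
        entry = _registry.get(type(node))
        if entry is None:  # unregistered objects are leaves
            leaves.append(node)
            return ("leaf",)
        flatten, _, preserve = entry
        label = None
        if preserve:
            if id(node) in seen:  # same object seen before: emit a back-reference
                return ("ref", seen[id(node)])
            label = seen[id(node)] = len(seen)
        children, aux = flatten(node)
        return ("node", type(node), aux, label, tuple(go(c) for c in children))

    treedef = go(tree)
    return leaves, treedef

def tree_unflatten(treedef, leaves):
    leaf_iter = iter(leaves)
    built: Dict[int, Any] = {}  # label -> object rebuilt during this call

    def go(d):
        if d[0] == "leaf":
            return next(leaf_iter)
        if d[0] == "ref":  # reuse the object built for the first occurrence
            return built[d[1]]
        _, cls, aux, label, child_defs = d
        obj = _registry[cls][1](aux, [go(c) for c in child_defs])
        if label is not None:
            built[label] = obj
        return obj

    return go(treedef)

class Child:
    def __init__(self, x):
        self.x = x

class Parent:
    def __init__(self, left, right):
        self.left, self.right = left, right

register_node(Child, lambda c: ([c.x], None), lambda aux, ch: Child(*ch),
              preserve_relative_identities=True)
register_node(Parent, lambda p: ([p.left, p.right], None),
              lambda aux, ch: Parent(*ch), preserve_relative_identities=True)

child = Child(1.0)
parent = Parent(child, child)
leaves, treedef = tree_flatten(parent)
parent2 = tree_unflatten(treedef, leaves)

assert parent2.left is parent2.right  # relative identity preserved
assert parent2.left is not child      # but not the original object
```

Types registered without the flag go through the ordinary path, so a shared instance of such a type is flattened and rebuilt once per occurrence, as today.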

Implications

All current code should run normally; only new code that opts in to this behavior will use this feature.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 11 (6 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, Sep 16, 2021

A terminological point: preserving object identity would mean violating referential transparency, and in particular these would be more like “pydags” than “pytrees”.

I don’t think we want to break referential transparency with existing pytree types. That would prevent us from processing them recursively in a functionally pure way (as @cgarciae already mentioned under “Implementation”, which refers to basically a side-effecting memoization process). Moreover I wouldn’t be surprised if we leverage the referential transparency assumption in lots of different places.

But just for pytree flattening/unflattening alone as in @cgarciae’s most recent comment, the API is already general enough to handle DAGs in your own custom pytree types, so long as you are willing to break referential transparency in your own flattening functions. You just need to do the deduplication-by-python-object-id (and equality-up-to-alpha-renaming) yourself:

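from typing import Any, NamedTuple
import itertools as it
from collections import defaultdict

import jax
from jax.tree_util import register_pytree_node

def unzip2(pairs):
  # Local stand-in for jax.util.unzip2: split an iterable of pairs into two tuples.
  xs, ys = [], []
  for x, y in pairs:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)

class MyTuple:
  elts: tuple[Any, ...]
  def __init__(self, *elts):
    self.elts = elts
  def __iter__(self):
    return iter(self.elts)

def flatten_mytuple(x):
  # Name each distinct object 0, 1, 2, ... in order of first appearance, so the
  # aux data records only the sharing pattern, not the raw id() values.
  counts = it.count()
  id_to_name = defaultdict(lambda: next(counts))
  name_list = tuple(id_to_name[id(e)] for e in x.elts)  # tuple keeps aux data hashable
  uniques = {id(e): e for e in x.elts}  # one entry per distinct object
  unique_names, unique_vals = unzip2((id_to_name[i], v) for i, v in uniques.items())
  return unique_vals, (unique_names, name_list)

def unflatten_mytuple(aux, unique_vals):
  unique_names, name_list = aux
  uniques = dict(zip(unique_names, unique_vals))
  # Positions that shared a name receive the same rebuilt object.
  elts = [uniques[name] for name in name_list]
  return MyTuple(*elts)

register_pytree_node(MyTuple, flatten_mytuple, unflatten_mytuple)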

class Module(NamedTuple): pass  # added this as trivial pytree

###

m = Module()
tree = MyTuple(m, m)  # NOTE: changed from Python tuple to MyTuple!

leaves, treedef = jax.tree_util.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)

m11, m12 = jax.tree_util.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_util.tree_unflatten(treedef, leaves2)

assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m
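
One way to read the “equality-up-to-alpha-renaming” remark: the names handed out by the defaultdict depend only on the order in which distinct objects first appear, not on their actual id values. The following sketch isolates that naming step; sharing_pattern is a hypothetical helper introduced here for illustration and is not part of the code above.

```python
import itertools as it
from collections import defaultdict

def sharing_pattern(elts):
    # Assign 0, 1, 2, ... to distinct objects in order of first appearance.
    counter = it.count()
    id_to_name = defaultdict(lambda: next(counter))
    return tuple(id_to_name[id(e)] for e in elts)

a, b = object(), object()
c, d = object(), object()

# Different objects, same sharing structure: the patterns compare equal.
assert sharing_pattern([a, a, b]) == sharing_pattern([c, c, d]) == (0, 0, 1)
# A different sharing structure gives a different pattern.
assert sharing_pattern([a, b, a]) == (0, 1, 0)
```

Because the pattern ends up in the aux data, two trees with the same sharing structure can produce equal tree definitions even though their object ids differ.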

I’m not sure of the limitations of this approach. For example, when Tracers are involved, this reliance on Python object identity might lead to surprising results (but maybe it’d be okay to rely on Python object identity just for values which can’t be wrapped in Tracers, i.e. in your own custom pytree data types?).

1 reaction
mattjj commented, Sep 17, 2021

“I think the main problem with your suggestion is that when you have nested structures like this […] local information is not enough because shared here is in separate branches, to solve this you need to keep track of all objects within the pytree during unflatten.”

You’re right, I meant to mention that but I neglected to: in general the flattening function would be responsible for flattening its whole subtree, not just one node as usual, by calling into your own set of stateful flatteners! That is, the pytree flattening function for MyTuple would recursively call into stateful flatteners for its children, basically the kind of generalized flattener you outlined (or alternatively these could just thread through a reference to a mutable object, like a dict, which would be a thread-safe alternative to global state). That way, you could deduplicate within any subtree under one of your pytree classes.

My example code did not do recursive DAG flattening. Here’s a version that does! First, a general PyDag system (based on the Python pytree implementation in Autodidax):

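from functools import partial
import itertools as it
from collections import defaultdict
from typing import (Callable, Type, Hashable, Dict, Any, NamedTuple, Tuple,
                    Sequence, List, Union)

def unzip2(pairs):
  # Local stand-in for jax.util.unzip2: split an iterable of pairs into two tuples.
  xs, ys = [], []
  for x, y in pairs:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)

Name = int
Names = Sequence[int]
AuxData = Any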

class NodeType(NamedTuple):
  name: str
  to_iterable: Callable[[Callable, Any], Tuple[AuxData, Names]]
  from_iterable: Callable[[Callable, AuxData, Names], Any]
  def __repr__(self): return f'DagNode[{self.name}]'

def register_pydag_node(ty: Type, to_iter: Callable, from_iter: Callable)-> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: Dict[Type, NodeType] = {}
register_pydag_node(tuple,
                    lambda f, t: (None, [f(e) for e in t]),
                    lambda u, _, names: tuple(u(n) for n in names))
register_pydag_node(list,
                    lambda f, l: (None, [f(e) for e in l]),
                    lambda u, _, names: [u(n) for n in names])
register_pydag_node(dict,
                    lambda f, d: (tuple(sorted(d)), [f(d[k]) for k in sorted(d)]),
                    lambda u, keys, names: {k: u(n) for k, n in zip(keys, names)})

class PyDagNode(NamedTuple):
  node_type: NodeType
  node_auxdata: Hashable
  names: Names

class PyDagLeaf(NamedTuple):
  name: Name

PyDagDef = Union[PyDagNode, PyDagLeaf]

class FlattenState(NamedTuple):
  id_to_name: Dict[int, Name]
  name_to_obj: Dict[Name, Any]

def dag_flatten(x: Any) -> Tuple[List[Any], Tuple[List[Name], PyDagDef]]:
  names = it.count()
  state = FlattenState(defaultdict(lambda: next(names)), dict())
  dagdef = _dag_flatten(state, x)
  unique_names, unique_vals = unzip2(state.name_to_obj.items())
  return unique_vals, (unique_names, dagdef)

def _dag_flatten(state: FlattenState, x: Any) -> PyDagDef:
  node_type = node_types.get(type(x))
  if node_type:
    node_auxdata, names = node_type.to_iterable(partial(_dag_flatten, state), x)
    return PyDagNode(node_type, node_auxdata, names)
  else:
    name = state.id_to_name[id(x)]
    state.name_to_obj[name] = x
    return PyDagLeaf(name)

class UnflattenState(NamedTuple):
  name_to_obj: Dict[Name, Any]

def dag_unflatten(dagspec: Tuple[Names, PyDagDef], xs: List[Any]) -> Any:
  names, dagdef = dagspec
  state = UnflattenState(dict(zip(names, xs)))
  return _dag_unflatten(state, dagdef)

def _dag_unflatten(state: UnflattenState, dagdef: PyDagDef) -> Any:
  if type(dagdef) is PyDagLeaf:
    return state.name_to_obj[dagdef.name]
  else:
    u = partial(_dag_unflatten, state)
    return dagdef.node_type.from_iterable(u, dagdef.node_auxdata, dagdef.names)

This is probably “hella buggy”, as we’d say where I come from, and I reserve the right to edit this github comment to fix embarrassing mistakes. But it passed literally one example I tried it on, so ship it!

Now, here’s a MyTuple pytree which calls into that dag flattening (i.e. interfaces pydags with the existing pytree system):

class MyTuple:
  elts: tuple[Any, ...]
  def __init__(self, *elts):
    self.elts = elts
  def __iter__(self):
    return iter(self.elts)

# register with our pydag system
register_pydag_node(MyTuple,
                    lambda f, t: (None, [f(e) for e in t]),
                    lambda u, _, names: MyTuple(*[u(n) for n in names]))

# register as a pytree with jax, but tell it to flatten like a dag
from jax.tree_util import register_pytree_node
register_pytree_node(MyTuple, dag_flatten, dag_unflatten)


###

import jax

class Module(NamedTuple): pass  # added this as trivial pytree

# Test 1

m = Module()
tree = MyTuple(m, m)

leaves, treedef = jax.tree_util.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)

m11, m12 = jax.tree_util.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_util.tree_unflatten(treedef, leaves2)

assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m

# Test 2

m = Module()
tree = MyTuple(MyTuple(m, m), MyTuple(m, m))
leaves, treedef = jax.tree_util.tree_flatten(tree)
((m11, m12), (m21, m22)) = jax.tree_util.tree_unflatten(treedef, leaves)
assert m11 is m12 is m21 is m22

My main point is just that I think with the existing pytree system you can at least flatten subtrees of your custom pytree types however you’d like, including as dags-by-objectid. Maybe that can unblock you!

Of course, we could also consider building some pydag behavior into JAX, which would let us deduplicate across all argument lists (even if the top-level container is not a custom pytree type you control). It’s worth considering! Like I said before, I’m a bit wary of where we might leverage the referential transparency assumption. But maybe it’d all work out… experimenting with the above pydag approach might help us learn things!

WDYT? Does this approach unblock you, without needing JAX-internal changes?
