Proposal: mechanism to preserve relative identities for custom Pytrees
Motivation
When `jax.tree_unflatten` is called, new instances of all pytree objects are created and their original identities are lost. This is expected for basic container types like lists, dicts, and tuples. For other types, such as Pytree Modules, it can be inconvenient because it makes tasks like parameter sharing difficult. Here is an example, which currently doesn't work, of sharing one `Child` module between two fields of the same `Parent` module:
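```python
import jax
import jax.numpy as jnp

# Module: a pytree-registered base class from the user's library (not defined here)

class Child(Module):
    x: jnp.ndarray  # assume this is a leaf
    ...

class Parent(Module):
    left: Child  # assume these are subtrees
    right: Child
    ...

child = Child(x=jnp.array(1))
parent = Parent(left=child, right=child)  # child is shared

@jax.jit
def f(parent):
    assert parent.left is parent.right  # fails today: sharing is lost inside jit
    return parent

parent2 = f(parent)
assert parent2.left is parent2.right  # fails today: sharing is lost in the output
```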
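To see why these assertions fail, here is a toy, pure-Python model of what flattening and unflattening do to `parent`. The `Child`/`Parent` dataclasses and the `flatten`/`unflatten` functions are hypothetical simplifications (JAX's real implementation is far more general), but they follow the same rule: each field is flattened independently, so the shared child's leaf appears twice, and unflattening builds a fresh instance for every node.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for the Child/Parent modules above (no JAX needed).
@dataclass
class Child:
    x: Any


@dataclass
class Parent:
    left: Child
    right: Child


def flatten(parent):
    # Each field is flattened on its own, so a shared Child contributes its
    # leaf once per occurrence; nothing records that both fields were one object.
    leaves = [parent.left.x, parent.right.x]
    treedef = ("Parent", ("Child", "Child"))
    return leaves, treedef


def unflatten(treedef, leaves):
    # This toy only supports the one shape recorded in treedef.
    # Every node is rebuilt as a fresh instance.
    left_x, right_x = leaves
    return Parent(Child(left_x), Child(right_x))


child = Child(1)
parent = Parent(left=child, right=child)
leaves, treedef = flatten(parent)
parent2 = unflatten(treedef, leaves)
print(leaves)                         # [1, 1]
print(parent2.left == parent2.right)  # True: equal values
print(parent2.left is parent2.right)  # False: sharing lost
```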
Proposal
Enable preserving the relative identities of custom Pytree classes that opt in to this behaviour.
By “relative identities” we mean that if two objects in a Pytree share the same identity before flattening, they will also share an identity with each other after unflattening, but not with their original objects. Assuming `Module` opted in to this behavior, the following assertions would hold:
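```python
import jax

m = Module()  # Module is assumed to have opted in

@jax.jit
def f(m1, m2):
    assert m1 is m2  # sharing is preserved inside the trace
    return m1, m2

m1, m2 = f(m, m)
assert m1 is m2  # and in the outputs
assert m is not m1 and m is not m2  # but they are new objects
```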
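The definition can be made precise by comparing aliasing patterns: replace each object by the position of its first occurrence, so that only relative identities remain. The helpers below (`aliasing_pattern`, `preserves_relative_identities`, and the stub `Module`) are hypothetical names for this sketch, not part of JAX.

```python
class Module:
    """Hypothetical stand-in for an opted-in pytree Module."""


def aliasing_pattern(objs):
    # Map each object to the index where it first appears, e.g. [a, b, a] -> (0, 1, 0).
    # Only relative identities survive; which concrete objects were used does not matter.
    first_seen = {}
    return tuple(first_seen.setdefault(id(o), i) for i, o in enumerate(objs))


def preserves_relative_identities(before, after):
    return aliasing_pattern(before) == aliasing_pattern(after)


m = Module()
new = Module()
print(preserves_relative_identities([m, m], [new, new]))            # True: proposed behavior
print(preserves_relative_identities([m, m], [Module(), Module()]))  # False: today's behavior
```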
Implementation
To achieve this, `register_pytree_node` could accept an optional `preserve_relative_identities: bool` flag (or something similar) indicating that objects of this class opt in to preserving their relative identities. When `tree_flatten` is called, each node's object `id` could be stored in the `PyTreeDef`. Then, when `tree_unflatten` unflattens a node whose class was registered with `preserve_relative_identities=True`, it checks whether it has already unflattened an element with that `id` and, if so, reuses that node.
`preserve_relative_identities` should also be available for `register_pytree_node_class`.
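A minimal pure-Python sketch of the proposed mechanism follows. Everything in it (`register`, `tree_flatten`, `tree_unflatten`, `TreeDef`, and the `preserve_relative_identities` flag) is a hypothetical toy modeled on the proposal, not JAX's API. One assumption goes slightly beyond the text: instead of storing raw `id`s, flattening renames them to small integers in first-seen order, so two trees with the same sharing structure produce equal treedefs.

```python
from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical registry modeling the proposed opt-in flag (not JAX's API).
REGISTRY = {}


@dataclass
class NodeType:
    flatten: Callable    # obj -> (children, aux_data)
    unflatten: Callable  # (aux_data, children) -> obj
    preserve_relative_identities: bool = False


def register(cls, flatten, unflatten, preserve_relative_identities=False):
    REGISTRY[cls] = NodeType(flatten, unflatten, preserve_relative_identities)


@dataclass(frozen=True)
class TreeDef:
    cls: Any      # None marks a leaf
    aux: Any
    node_id: Any  # renamed object id for opted-in nodes, else None
    children: tuple


def tree_flatten(obj):
    leaves, ids = [], {}  # ids: id(obj) -> small int in first-seen order

    def go(o):
        nt = REGISTRY.get(type(o))
        if nt is None:
            leaves.append(o)
            return TreeDef(None, None, None, ())
        children, aux = nt.flatten(o)
        node_id = ids.setdefault(id(o), len(ids)) if nt.preserve_relative_identities else None
        return TreeDef(type(o), aux, node_id, tuple(go(c) for c in children))

    return leaves, go(obj)


def tree_unflatten(treedef, leaves):
    it, memo = iter(leaves), {}  # memo: node_id -> object rebuilt in this call

    def go(d):
        if d.cls is None:
            return next(it)
        children = [go(c) for c in d.children]  # always consume this subtree's leaves
        if d.node_id is not None and d.node_id in memo:
            return memo[d.node_id]  # already rebuilt: reuse it
        obj = REGISTRY[d.cls].unflatten(d.aux, children)
        if d.node_id is not None:
            memo[d.node_id] = obj
        return obj

    return go(treedef)


class Child:
    def __init__(self, x):
        self.x = x


class Parent:
    def __init__(self, left, right):
        self.left, self.right = left, right


register(Child, lambda c: ([c.x], None), lambda aux, ch: Child(*ch),
         preserve_relative_identities=True)
register(Parent, lambda p: ([p.left, p.right], None), lambda aux, ch: Parent(*ch))

child = Child(1)
leaves, treedef = tree_flatten(Parent(child, child))
parent2 = tree_unflatten(treedef, leaves)
print(leaves)                         # [1, 1]
print(parent2.left is parent2.right)  # True
print(parent2.left is child)          # False
```

In this sketch the shared child's leaves are still emitted once per occurrence; on unflattening, the second copy is consumed but discarded. Under `jit`, where each leaf position gets its own tracer, that silent discard is one of the subtleties a real implementation would likely need to settle.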
Implications
All existing code should run unchanged; only new code that opts in to this behavior will use this feature.
Discussion (issue created 2 years ago; 11 comments, 6 by maintainers)
A terminological point: preserving object identity would mean violating referential transparency; in particular, the resulting structures would be more like “pydags” (directed acyclic graphs) than “pytrees”.
I don’t think we want to break referential transparency for existing pytree types. That would prevent us from processing them recursively in a functionally pure way (as @cgarciae already noted under “Implementation”, which basically describes a side-effecting memoization process). Moreover, I wouldn’t be surprised if we rely on the referential transparency assumption in lots of different places.
But for pytree flattening/unflattening alone, as in @cgarciae’s most recent comment, the API is already general enough to handle DAGs in your own custom pytree types, as long as you are willing to break referential transparency in your own flattening functions. You just need to do the deduplication by Python object id (and equality up to alpha-renaming) yourself.
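As a minimal sketch of deduplication within a single custom node, assuming a hypothetical `MyTuple` container, the flatten function below emits each distinct child once and records the sharing pattern in aux_data. The functions follow the `(children, aux_data)` / `(aux_data, children)` contract expected by `jax.tree_util.register_pytree_node`, but JAX is not imported so the sketch stays self-contained; it illustrates the idea and is not the code from the original comment.

```python
class MyTuple:
    """Hypothetical container whose elements may alias one another."""

    def __init__(self, *items):
        self.items = items


def mytuple_flatten(t):
    # Emit each distinct element (by Python object id) once, in first-seen order,
    # and record which unique child each slot holds.
    unique, index_of, pattern = [], {}, []
    for item in t.items:
        if id(item) not in index_of:
            index_of[id(item)] = len(unique)
            unique.append(item)
        pattern.append(index_of[id(item)])
    # aux_data holds small integers, not raw ids, so trees with the same sharing
    # structure get equal aux_data (equality up to alpha-renaming).
    return unique, tuple(pattern)


def mytuple_unflatten(pattern, children):
    # Rebuild the slots from the deduplicated children, restoring the sharing.
    return MyTuple(*(children[i] for i in pattern))


x = [1]
children, aux = mytuple_flatten(MyTuple(x, x, [1]))
print(aux)                                   # (0, 0, 1)
rebuilt = mytuple_unflatten(aux, [[10], [20]])
print(rebuilt.items)                         # ([10], [10], [20])
print(rebuilt.items[0] is rebuilt.items[1])  # True
```

With JAX installed, these two functions could be passed to `jax.tree_util.register_pytree_node(MyTuple, mytuple_flatten, mytuple_unflatten)`. Storing first-seen indices rather than raw `id`s matters because JAX compares treedefs, including aux_data (for example, for `jit` cache hits). The example also shows the break in referential transparency: `MyTuple(x, x)` and `MyTuple(x, [1])` hold equal elements but flatten differently.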
I’m not sure of the limitations of this approach. For example, when Tracers are involved, this reliance on Python object identity might lead to surprising results (but maybe it’d be okay to rely on Python object identity only for values which can’t be wrapped in Tracers, i.e. your own custom pytree data types?).
You’re right, I meant to mention that but neglected to: in general, the flattening function would be responsible for flattening its whole subtree, not just one node as usual, by calling into your own set of stateful flatteners! That is, the pytree flattening function for `MyTuple` would recursively call into stateful flatteners for its children, basically the kind of generalized flattener you outlined (alternatively, they could just thread through a reference to a mutable object, like a dict, which would be a thread-safe alternative to global state). That way, you could deduplicate within any subtree under one of your pytree classes.

My example code did not do recursive DAG flattening. Here’s a version that does! First, a general `PyDag` system (based on the Python pytree implementation in Autodidax).

This is probably “hella buggy”, as we’d say where I come from, and I reserve the right to edit this GitHub comment to fix embarrassing mistakes. But it passed literally one example I tried it on, so ship it!
Now, here’s a `MyTuple` pytree that calls into that DAG flattening (i.e. interfaces pydags with the existing pytree system).
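A minimal sketch of this recursive approach, under our own assumptions: the `DAG_TYPES` registry, `dag_flatten`, `dag_unflatten`, `Node`, `Ref`, `Child`, and `MyTuple` are all hypothetical, and this is not the code from the comment. One `seen` dict is threaded through the whole recursive flatten, so a node reached twice becomes a back-reference `Ref(k)` and its leaves are emitted only once. `MyTuple` then exposes the result through `tree_flatten`/`tree_unflatten` methods shaped like the protocol `jax.tree_util.register_pytree_node_class` expects; JAX itself is not imported.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical pydag registry: cls -> (to_children, from_children).
DAG_TYPES = {}


def register_dag_type(cls, to_children, from_children):
    DAG_TYPES[cls] = (to_children, from_children)


@dataclass(frozen=True)
class Node:
    cls: Any
    children: tuple  # child dagdefs; None marks a leaf


@dataclass(frozen=True)
class Ref:
    index: int  # pre-order visit number of a node already flattened


def dag_flatten(obj, seen=None, leaves=None):
    # `seen` (id(node) -> visit number) is the mutable state threaded through the
    # recursion, so a node reached twice becomes Ref(k) and its leaves appear once.
    seen = {} if seen is None else seen
    leaves = [] if leaves is None else leaves
    if type(obj) not in DAG_TYPES:
        leaves.append(obj)
        return None, leaves
    if id(obj) in seen:
        return Ref(seen[id(obj)]), leaves
    seen[id(obj)] = len(seen)
    to_children, _ = DAG_TYPES[type(obj)]
    kids = tuple(dag_flatten(c, seen, leaves)[0] for c in to_children(obj))
    return Node(type(obj), kids), leaves


def dag_unflatten(dagdef, leaves):
    it, built = iter(leaves), []  # built: visit number -> rebuilt object

    def go(d):
        if d is None:
            return next(it)
        if isinstance(d, Ref):
            return built[d.index]  # shared node: reuse the one already rebuilt
        k = len(built)
        built.append(None)  # reserve the visit number in the same pre-order
        _, from_children = DAG_TYPES[d.cls]
        built[k] = from_children([go(c) for c in d.children])
        return built[k]

    return go(dagdef)


class Child:
    def __init__(self, x):
        self.x = x


register_dag_type(tuple, lambda t: t, tuple)
register_dag_type(Child, lambda c: (c.x,), lambda kids: Child(*kids))


class MyTuple:
    # tree_flatten / tree_unflatten follow the protocol used by
    # jax.tree_util.register_pytree_node_class; JAX is not imported here.
    def __init__(self, *items):
        self.items = items

    def tree_flatten(self):
        dagdef, leaves = dag_flatten(self.items)
        return leaves, dagdef  # the hashable dagdef serves as aux_data

    @classmethod
    def tree_unflatten(cls, dagdef, leaves):
        return cls(*dag_unflatten(dagdef, leaves))


c = Child(1.0)
leaves, aux = MyTuple(c, c, Child(2.0)).tree_flatten()
print(leaves)  # [1.0, 2.0]: the shared Child contributes its leaf only once
t2 = MyTuple.tree_unflatten(aux, [10.0, 20.0])
print(t2.items[0] is t2.items[1], t2.items[2].x)  # True 20.0
```

Because visit numbers are assigned in pre-order on both sides, `dag_unflatten` can resolve each `Ref(k)` to a node it has already finished rebuilding; this relies on the structure being acyclic. The dagdef is built from frozen dataclasses and tuples, so it is hashable and compares equal across trees with the same sharing structure, as aux_data needs to.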
My main point is just that I think with the existing pytree system you can at least flatten subtrees of your custom pytree types however you’d like, including as dags-by-objectid. Maybe that can unblock you!
Of course, we could also consider building some pydag behavior into JAX, which would let us deduplicate across all argument lists (even if the top-level container is not a custom pytree type you control). It’s worth considering! Like I said before, I’m a bit wary of where we might leverage the referential transparency assumption. But maybe it’d all work out… experimenting with the above pydag approach might help us learn things!
WDYT? Does this approach unblock you, without needing JAX-internal changes?