pytree transformation
I’d like to write code like np.sum((x - y) ** 2) that automatically works on arbitrarily nested Python objects.
I think this could be done cleanly with a “pytree transformation” that acts on functions, somewhat similar to the existing vmap, e.g.,
@jax.autotree
def norm(x, y):
    return np.sum((x - y) ** 2)
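Nothing named jax.autotree exists today; it is the decorator this issue proposes. For comparison, here is a minimal sketch of the same norm written by hand with jax.tree_util, assuming x and y have identical pytree structure and array leaves. The name tree_norm is hypothetical and chosen only for illustration.

```python
import jax.numpy as jnp
from jax import tree_util


def tree_norm(x, y):
    """Hand-written pytree version of np.sum((x - y) ** 2).

    Assumes x and y share the same pytree structure with array leaves.
    """
    # Elementwise step: apply (a - b) ** 2 to each pair of matching leaves.
    sq_diff = tree_util.tree_map(lambda a, b: (a - b) ** 2, x, y)
    # Reduction step: sum each leaf to a scalar, then add those scalars up.
    return sum(jnp.sum(leaf) for leaf in tree_util.tree_leaves(sq_diff))


x = {'a': jnp.array([1.0, 2.0]), 'b': [jnp.ones((2, 2)), jnp.array(3.0)]}
y = {'a': jnp.array([0.0, 0.0]), 'b': [jnp.zeros((2, 2)), jnp.array(1.0)]}
print(tree_norm(x, y))  # 5 + 4 + 4 = 13.0
```

The tree_map/tree_leaves plumbing is the kind of boilerplate the proposed decorator would hide.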
Under the hood, we would need to define what “autotree” versions of each primitive function look like on a list of heterogeneously shaped arrays, e.g.,
- Vectorized functions like + operate over each array independently.
- In mixed tree/array operations, the array should get “broadcast” by repeating it to match the structure of the tree, e.g., {'a': x, 'b': y} + z -> {'a': x + z, 'b': y + z}.
- Reductions like sum reduce over all the arrays together to produce a scalar, unless an axis is specified.
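The three rules above can be sketched directly on top of jax.tree_util. tree_add and tree_sum are hypothetical helpers, not JAX API, and the broadcasting check assumes that a "bare array" is any value whose pytree structure is a single leaf.

```python
import jax.numpy as jnp
from jax import tree_util


def _is_bare_array(v):
    # Assumption: a value whose pytree structure is a single leaf (a jnp
    # array or a Python scalar) is an array to broadcast, not a tree.
    return tree_util.treedef_is_leaf(tree_util.tree_structure(v))


def tree_add(a, b):
    """Rules 1 and 2: leafwise +, broadcasting a bare array over a tree."""
    if _is_bare_array(b) and not _is_bare_array(a):
        return tree_util.tree_map(lambda leaf: leaf + b, a)
    if _is_bare_array(a) and not _is_bare_array(b):
        return tree_util.tree_map(lambda leaf: a + leaf, b)
    # Same structure on both sides: add corresponding leaves.
    return tree_util.tree_map(jnp.add, a, b)


def tree_sum(t, axis=None):
    """Rule 3: with no axis, reduce every leaf into one scalar."""
    if axis is None:
        return sum(jnp.sum(leaf) for leaf in tree_util.tree_leaves(t))
    # With an axis, reduce each leaf separately and keep the tree structure.
    return tree_util.tree_map(lambda leaf: jnp.sum(leaf, axis=axis), t)


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
z = jnp.array([10.0, 20.0])
out = tree_add({'a': x, 'b': y}, z)  # {'a': [11., 22.], 'b': [13., 24.]}
print(tree_sum(out))  # 70.0
```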
The upside is that this could make it trivial to write functional transformations like jax.experimental.optimizers in a fully general way. I can think of lots of nice applications for this, e.g., for methods that solve equations.
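As a rough illustration of the optimizer point, an update rule written once in terms of jax.tree_util.tree_map applies unchanged to any parameter pytree. This is a minimal sketch of SGD with momentum; it does not reproduce the jax.experimental.optimizers API, and sgd_momentum_step and the toy loss are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def sgd_momentum_step(params, grads, velocity, lr=0.1, beta=0.9):
    """One SGD-with-momentum update for an arbitrary pytree of params."""
    # v <- beta * v + g, applied leaf by leaf.
    velocity = tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    # p <- p - lr * v, applied leaf by leaf.
    params = tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity


def loss(params):
    # Toy loss: squared norm summed over every leaf in the tree.
    return sum(jnp.sum(p ** 2) for p in tree_util.tree_leaves(params))


params = {'w': jnp.array([1.0, -2.0]), 'b': (jnp.array(0.5),)}
velocity = tree_util.tree_map(jnp.zeros_like, params)
for _ in range(200):
    grads = jax.grad(loss)(params)  # grads has the same structure as params
    params, velocity = sgd_momentum_step(params, grads, velocity)
print(loss(params))
```

Nothing in sgd_momentum_step knows the shape of the tree, which is the generality the issue is after.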
Issue Analytics
- Created 4 years ago
- Reactions: 10
- Comments: 12 (12 by maintainers)
I need to figure out how to get #3263 moving towards something we can merge. The main unresolved challenge to make this useful is figuring out how to get JAX transformations and control flow inside tree-vectorized functions to work. There is a lot of reorganizing data for new calling conventions that gets rather tedious, but maybe we can just power through it or figure out a better way, e.g., perhaps with partial tree flattening using the new(ish) is_leaf parameter.

With regard to in_trees/out_trees, that seems like a very sensible suggestion (similar to in_axes/out_axes for vmap), although you can get a similar effect to in_trees just by defining an inner function to apply tree_vectorize to using a closure. I think this is somewhat more of a superficial issue, though: the key challenge to tackle first is figuring out how to get the fundamental JAX transformations inside tree_vectorize working.

CC @apaszke
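For reference, the is_leaf parameter mentioned above lets jax.tree_util.tree_flatten stop descending at chosen nodes. The sketch below is one hedged example of partial flattening, assuming a tree whose meaningful units are (value, scale) tuples; it is not how #3263 or the newer approach is implemented, and is_pair is a hypothetical predicate.

```python
import jax.numpy as jnp
from jax import tree_util

# A tree whose "units" are (value, scale) pairs rather than single arrays.
tree = {'a': (jnp.array([1.0, 2.0]), 2.0), 'b': [(jnp.array(3.0), 10.0)]}

# Full flattening splits every pair apart: 4 separate leaves.
full_leaves = tree_util.tree_leaves(tree)

# With is_leaf, flattening stops at each tuple, so each pair stays intact.
is_pair = lambda node: isinstance(node, tuple)
pairs, treedef = tree_util.tree_flatten(tree, is_leaf=is_pair)

# Operate on whole pairs, then rebuild only the outer structure.
scaled = tree_util.tree_unflatten(treedef, [v * s for v, s in pairs])
print(len(full_leaves), len(pairs))  # 4 2
print(scaled)  # {'a': [2., 4.], 'b': [30.]}
```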
Update: I have a new approach for implementing this, which I think is much more promising than #3263:
Feedback would be very welcome!