
pytree transformation


I’d like to write code like np.sum((x - y) ** 2) that automatically works on arbitrarily nested Python objects.

I think this could be done cleanly with a “pytree transformation” that acts on functions, somewhat similar to the existing vmap, e.g.,

@jax.autotree
def norm(x, y):
    return np.sum((x - y) ** 2)
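For comparison, here is a minimal sketch of what one has to write by hand today with `jax.tree_util` to make the same norm work on nested inputs. The proposed `autotree` transformation would generate this plumbing automatically. `tree_norm` is a hypothetical name, and the sketch assumes `x` and `y` share exactly the same tree structure.

```python
import jax
import jax.numpy as jnp


def tree_norm(x, y):
    # Subtract leaf by leaf; tree_map requires x and y to have matching structure.
    diff = jax.tree_util.tree_map(lambda a, b: a - b, x, y)
    # Square and sum each leaf down to a scalar.
    per_leaf = jax.tree_util.tree_map(lambda d: jnp.sum(d ** 2), diff)
    # Add the per-leaf scalars together to get a single scalar.
    return sum(jax.tree_util.tree_leaves(per_leaf))


x = {'a': jnp.array([1.0, 2.0]), 'b': [jnp.array(3.0)]}
y = {'a': jnp.array([0.0, 0.0]), 'b': [jnp.array(1.0)]}
print(tree_norm(x, y))  # 1 + 4 + 4 = 9.0
```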

Under the hood, we would need to define what “autotree” versions of each primitive function look like on a list of heterogeneously shaped arrays, e.g.,

  • Vectorized functions like + operate over each array independently.
  • In mixed tree/array operations, the array should get “broadcast” by repeating it to match the structure of the tree, e.g., {'a': x, 'b': y} + z -> {'a': x + z, 'b': y + z}.
  • Reductions like sum reduce over all the arrays together to produce a single scalar, unless an axis is specified.
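The three rules above can be sketched for two primitives, `+` and `sum`, using `jax.tree_util`. This is only an illustration of the intended semantics, not how an `autotree` transformation would actually be wired into JAX's primitives: `autotree_add` and `autotree_sum` are hypothetical helpers, and the behaviour when an `axis` is given (reduce each leaf separately and keep the tree) is an assumption, since the proposal does not spell it out.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map, tree_structure


def _is_leaf(x):
    # A bare array or Python scalar flattens to a single-node, single-leaf tree.
    treedef = tree_structure(x)
    return treedef.num_nodes == 1 and treedef.num_leaves == 1


def autotree_add(x, y):
    """Hypothetical autotree rule for `+`."""
    if _is_leaf(x) and _is_leaf(y):
        return x + y                               # ordinary array + array
    if _is_leaf(y):
        return tree_map(lambda a: a + y, x)        # broadcast y over every leaf of x
    if _is_leaf(x):
        return tree_map(lambda b: x + b, y)        # broadcast x over every leaf of y
    return tree_map(lambda a, b: a + b, x, y)      # leafwise; structures must match


def autotree_sum(tree, axis=None):
    """Hypothetical autotree rule for `sum`."""
    if axis is None:
        # Reduce every leaf to a scalar, then add those scalars together.
        return sum(jnp.sum(leaf) for leaf in tree_leaves(tree))
    # Assumption: with an explicit axis, reduce each leaf on its own and keep the tree.
    return tree_map(lambda leaf: jnp.sum(leaf, axis=axis), tree)


tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0, 4.0, 5.0])}
z = jnp.array(10.0)

shifted = autotree_add(tree, z)        # {'a': [11., 12.], 'b': [13., 14., 15.]}
total = autotree_sum(tree)             # 15.0
per_leaf = autotree_sum(tree, axis=0)  # {'a': 3.0, 'b': 12.0}
```

Other elementwise primitives (`-`, `*`, `**`) would follow the same four-case pattern as `autotree_add`.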

The upside is that this could make it trivial to write functional transformations like jax.experimental.optimizers in a fully general way. I can think of lots of nice applications for this, e.g., for methods that solve equations.
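As a rough illustration of the "fully general" point, here is a minimal gradient-descent sketch written once for any pytree of parameters. It uses plain `jax.tree_util.tree_map` rather than the proposed transformation, and `sgd_step`, the toy linear model, and the learning rate are hypothetical choices made only for this example.

```python
import jax
import jax.numpy as jnp


def sgd_step(params, grads, lr):
    # Works for any pytree of parameters: tree_map pairs each parameter leaf
    # with the gradient leaf at the same position and updates it.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def loss(params, x, y):
    # A tiny linear model whose parameters live in a dict.
    pred = x @ params['w'] + params['b']
    return jnp.mean((pred - y) ** 2)


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
params = {'w': jnp.zeros(2), 'b': jnp.array(0.0)}

initial_loss = loss(params, x, y)
for _ in range(100):
    grads = jax.grad(loss)(params, x, y)  # grads has the same structure as params
    params = sgd_step(params, grads, lr=0.01)
final_loss = loss(params, x, y)
```

With an `autotree`-style transformation, the body of `sgd_step` could in principle be written as plain array arithmetic, `params - lr * grads`.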

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 10
  • Comments: 12 (12 by maintainers)

Top GitHub Comments

3 reactions
shoyer commented, Jun 3, 2021

I need to figure out how to get #3263 moving towards something we can merge. The main unresolved challenge in making this useful is getting JAX transformations and control flow to work inside tree-vectorized functions. Reorganizing data for the new calling conventions gets rather tedious, but maybe we can just power through it or find a better way, e.g., perhaps partial tree flattening using the new(ish) is_leaf parameter.
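For readers unfamiliar with `is_leaf`: it lets `jax.tree_util` stop flattening early, treating selected subtrees as single leaves. The sketch below only shows that mechanism on a toy tree; how #3263 would actually use it is not specified in the comment, and the choice of "tuples are leaves" is an arbitrary example.

```python
import jax

tree = {'a': (1.0, 2.0), 'b': [3.0, (4.0, 5.0)]}

# Default: flatten all the way down to the scalars.
full = jax.tree_util.tree_leaves(tree)
# -> [1.0, 2.0, 3.0, 4.0, 5.0]

# Partial flattening: stop descending at tuples and treat each tuple as one leaf.
partial, treedef = jax.tree_util.tree_flatten(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)
# -> [(1.0, 2.0), 3.0, (4.0, 5.0)]

# The outer structure can be rebuilt from the partially flattened leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, partial)
```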

With regard to in_trees/out_trees, that seems like a very sensible suggestion (similar to in_axes/out_axes for vmap), although you can get an effect similar to in_trees by applying tree_vectorize to an inner function that closes over the arguments that should not be vectorized. I think this is more of a superficial issue, though: the key challenge to tackle first is getting the fundamental JAX transformations working inside tree_vectorize.
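Since `tree_vectorize` comes from the unmerged #3263 and its signature is not shown here, the closure trick is easiest to see with its existing analogue, `jax.vmap` and `in_axes`. Both calls below leave `y` unbatched: the first says so explicitly, the second hides `y` inside a closure so the transformation only ever sees `x`. The function `f` and the inputs are illustrative.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * y


xs = jnp.arange(3.0)  # argument that should be mapped over
y = jnp.array(10.0)   # argument that should be held fixed

# Option 1: tell vmap explicitly that y is not batched.
out1 = jax.vmap(f, in_axes=(0, None))(xs, y)

# Option 2: close over y so the transformed function only takes x.
out2 = jax.vmap(lambda x: f(x, y))(xs)
# Both give [0., 10., 20.]
```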

CC @apaszke

2 reactions
shoyer commented, Nov 10, 2021

Update: I have a new approach for implementing this, which I think is much more promising than #3263:

Feedback would be very welcome!
