
Getting all values in a pytree


Hi, sorry if this is a question with an obvious answer, but how do I get all the values in a pytree? I tried using tree_flatten, but it gives me the device arrays rather than the values they contain. (For extra context: I want to implement regularizers using JAX, so I need to sum the squares or absolute values of all the neural network parameters.)
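The leaves that tree_flatten returns already are the values: each leaf is an array whose contents can be used directly in jax.numpy operations, so no separate "unwrap" step is needed. A minimal sketch, assuming a hypothetical two-layer parameter dict (the names dense1, dense2, w and b are illustrative only):

```python
import jax.numpy as jnp
import numpy as np
from jax import tree_util

# Hypothetical parameter pytree: a dict of layers, each holding weight/bias arrays.
params = {
  "dense1": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
  "dense2": {"w": jnp.full((3, 1), 2.0), "b": jnp.array([0.5])},
}

# tree_flatten returns (list_of_leaves, treedef); the leaves are the arrays themselves.
# Dict keys are visited in sorted order, so the leaves are dense1/b, dense1/w, dense2/b, dense2/w.
leaves, treedef = tree_util.tree_flatten(params)

# jnp operations work on each leaf directly; this sums every number in the tree.
total = sum(jnp.sum(leaf) for leaf in leaves)

# To inspect the numbers on the host as NumPy arrays, convert explicitly.
host_values = [np.asarray(leaf) for leaf in leaves]
```

Converting to NumPy is only needed for inspection or logging; for a regularizer it is better to keep the computation in jax.numpy so it stays differentiable and jit-compatible.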

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 6 (5 by maintainers)

Top GitHub Comments

jakevdp commented, Aug 3, 2021 (3 reactions)

It does still exist:

from jax.flatten_util import ravel_pytree

It looks like it’s not listed in JAX’s documentation.
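A minimal sketch of using ravel_pytree for the regularizer use case, assuming a hypothetical parameter dict. ravel_pytree returns a pair: a single 1-D array holding every leaf, and a function that maps such a vector back to the original structure.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter pytree (illustrative names and shapes).
params = {"w": jnp.array([[1.0, -2.0], [3.0, -4.0]]), "b": jnp.array([0.5, -0.5])}

# Concatenate every leaf into one 1-D vector; also get the inverse mapping.
flat, unravel = ravel_pytree(params)

l2_sq = jnp.sum(flat ** 2)     # sum of squares of all parameters
l1 = jnp.sum(jnp.abs(flat))    # sum of absolute values of all parameters

restored = unravel(flat)       # same structure and shapes as params
```

The convenience comes at a cost: raveling builds a new contiguous buffer containing every parameter, so it is most useful when a flat vector is actually needed (e.g. for a generic optimizer), not just to compute a sum.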

mattjj commented, May 17, 2020 (3 reactions)

Raveling works, but it might actually be less efficient, because it probably involves concatenating all the values into a single contiguous buffer. It might be better to do something like:

import jax.numpy as jnp
from jax import tree_util

def l2_normsq(x):
  leaves, _ = tree_util.tree_flatten(x)  # the leaf arrays of the pytree
  return sum([jnp.sum(leaf ** 2) for leaf in leaves])

(where sum is just Python’s built-in sum.)
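Extending the same leaf-wise pattern to both regularizers mentioned in the question, here is a minimal sketch. It uses jax.tree_util.tree_leaves (which returns only the leaves, without the treedef) and plugs the penalties into a loss for jax.grad. The linear model, the data, and the l1_coef/l2_coef names are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp

def l2_penalty(params):
  # Sum of squares over every leaf of the pytree.
  return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))

def l1_penalty(params):
  # Sum of absolute values over every leaf of the pytree.
  return sum(jnp.sum(jnp.abs(leaf)) for leaf in jax.tree_util.tree_leaves(params))

# Hypothetical linear model, used only to show the penalties inside a loss.
def loss(params, x, y, l1_coef=0.0, l2_coef=1e-3):
  pred = x @ params["w"] + params["b"]
  mse = jnp.mean((pred - y) ** 2)
  return mse + l1_coef * l1_penalty(params) + l2_coef * l2_penalty(params)

params = {"w": jnp.array([[1.0], [-2.0]]), "b": jnp.array([0.5])}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([[1.0], [-2.0]])

# Differentiates with respect to the first argument; grads has params' structure.
grads = jax.grad(loss)(params, x, y)
```

Because the Python loop runs over the (fixed) list of leaves at trace time, it also works unchanged under jax.jit.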


