Getting all values in a pytree
Hi, sorry if this is a question with an obvious answer, but how do I get all the values in a pytree?
I tried using tree_flatten, but it gives me the device arrays, not the values they contain.
(For extra context: I want to implement regularizers using JAX, and need to sum the squares or absolute values of all the neural network parameters.)
Issue Analytics
- State:
- Created 3 years ago
- Comments:6 (5 by maintainers)
It does still exist:
It looks like it’s not listed in JAX’s documentation.
Raveling works, but it might actually be less efficient, because it probably involves concatenating all the values into a single contiguous buffer. It might be better to do something like
(where sum is just Python’s builtin.)
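For reference, here is a minimal sketch of what such a leaf-wise reduction could look like. It is not the snippet originally posted in this thread: the helper names l2_penalty and l1_penalty and the example params dictionary are hypothetical, and it assumes jax.tree_util.tree_leaves is used to list the arrays in the pytree. Each leaf is reduced to a scalar on its own, and Python's builtin sum adds up the scalars, so nothing is concatenated into one buffer.

```python
import jax
import jax.numpy as jnp


def l2_penalty(params):
    # Hypothetical helper: sum of squares over every leaf array in the pytree.
    leaves = jax.tree_util.tree_leaves(params)
    # Python's builtin sum adds the per-leaf scalar results together.
    return sum(jnp.sum(jnp.square(x)) for x in leaves)


def l1_penalty(params):
    # Hypothetical helper: sum of absolute values over every leaf array.
    leaves = jax.tree_util.tree_leaves(params)
    return sum(jnp.sum(jnp.abs(x)) for x in leaves)


# Made-up parameters for illustration; any nested dict/list/tuple of arrays works.
params = {
    "dense1": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
    "dense2": {"w": -2.0 * jnp.ones((3, 1)), "b": jnp.ones(1)},
}

l2 = l2_penalty(params)
l1 = l1_penalty(params)
```

Because the leaves are ordinary JAX arrays, the values they contain can be used directly in jnp operations, and penalties written this way should also work under jax.grad and jax.jit.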