Automatically treat dataclasses as pytrees
JAX should automatically treat dataclasses as pytrees, so they don’t have to be explicitly registered.
Ideally we would also support some syntax for non-differentiable parameters. Flax does so by adding custom metadata into dataclasses.Field.metadata with the special flax.struct.field constructor, which seems like a very clean way to do this.
I started working on this in a branch, but haven’t tested anything, so it is very likely entirely broken/non-functional! If somebody wants to finish this up it would be awesome 😃 https://github.com/google/jax/compare/master...shoyer:dataclasses-pytree
Issue Analytics
- State:
- Created 4 years ago
- Reactions:15
- Comments:21 (14 by maintainers)
Hello all!
@tomhennigan, @NeilGirdhar thanks for your input.
> Can the same reasoning be applied to lists, tuples and namedtuples? Often a user does not want to differentiate through these structures either.
Agreed, and I think the same option should be available for other structures. Actually, this is another important missing bit in JAX! I raised that issue before in https://github.com/google/jax/issues/2588.
Example: a non-parametric Gaussian process model with trainable hyperparameters, namely the lengthscale and variance of the squared exponential kernel. Often there is a need to experiment with the model by computing gradients w.r.t. only the variance, only the lengthscale, or both. I had to write different code for each case, which is super annoying considering that other frameworks (TF and PyTorch) support trainability of tensors out of the box.
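To make the complaint concrete, here is a minimal sketch of the per-case workaround using the `argnums` parameter of `jax.grad`. The toy `loss` (the sum of all kernel matrix entries) and the input values are assumptions made up for illustration; they are not taken from the thread.

```python
import jax
import jax.numpy as jnp


def loss(lengthscale, variance, x):
    """Toy objective (hypothetical): sum of a squared exponential kernel matrix."""
    diff = x[:, None] - x[None, :]
    kernel = variance * jnp.exp(-0.5 * (diff / lengthscale) ** 2)
    return jnp.sum(kernel)


x = jnp.array([0.0, 1.0, 2.0])

# Each choice of "trainable" hyperparameters needs its own grad call.
grad_lengthscale = jax.grad(loss, argnums=0)(1.0, 2.0, x)
grad_variance = jax.grad(loss, argnums=1)(1.0, 2.0, x)
grad_both = jax.grad(loss, argnums=(0, 1))(1.0, 2.0, x)  # tuple of two gradients
```

Selecting the differentiated arguments at the call site works, but it is positional and separate from the model definition, which is the gap a per-field trainability marker would close.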
Long story short: I think two features will bring more users to JAX:
What are the next steps? @shoyer, @mattjj
Named tuples are Iterable and are a subclass of tuple. One could argue it’s an implementation detail that PyTree’s handling of tuple works only on the tuple type and not on any subclasses. Not handling NamedTuples as tuples would create an inconsistency with how most APIs deal with tuples and Iterable, because those will handle NamedTuples as ordinary tuples. But I think the real problem with NamedTuple is that inheriting from tuple is more a buggy side effect than a feature. With dataclass we have a better alternative, and most use cases of NamedTuple should disappear, except maybe when there is truly an order in the fields (like a NamedSequential structure, for example). But I guess deregistering could help push adoption of the right pattern.
I think registering dataclasses by default (with a flag or always) does lead to problems. It might be worth considering allowing dataclasses to be registered post-hoc, e.g.:
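A minimal sketch of what post-hoc registration could look like, assuming a hypothetical helper `register_dataclass_as_pytree` built on `jax.tree_util.register_pytree_node`. The helper name and the `Params` class are illustrative only, and this is not necessarily the API the comment had in mind. It also assumes every field is accepted positionally by the generated `__init__`.

```python
import dataclasses

import jax
import jax.numpy as jnp


def register_dataclass_as_pytree(cls):
    """Hypothetical helper: register an already-defined dataclass as a pytree node."""
    names = [f.name for f in dataclasses.fields(cls)]

    def flatten(obj):
        # Every field becomes a child; there is no auxiliary data.
        return tuple(getattr(obj, name) for name in names), None

    def unflatten(aux, children):
        # Assumes all fields are init fields, in declaration order.
        return cls(*children)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@dataclasses.dataclass
class Params:
    w: jnp.ndarray
    b: jnp.ndarray


# Registration happens after the class definition, only where it is wanted.
register_dataclass_as_pytree(Params)

params = Params(w=jnp.ones(3), b=jnp.zeros(1))
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
```

Dataclasses that are never passed to the helper stay opaque to JAX, which keeps the opt-in behaviour the comment argues for.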
Having a field decorator for defining metadata can be really useful at times. In JAX it often avoids the need to declare functions and flag-like arguments in static_argnums.
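As a rough illustration of the field-metadata idea, the sketch below marks fields as static through `dataclasses.field(metadata=...)` and stores them in the treedef's auxiliary data. The names `static_field`, `register_with_static_fields`, the `"static"` metadata key, and the `Dense` class are all hypothetical; this mimics the spirit of flax.struct.field rather than reproducing its implementation.

```python
import dataclasses

import jax
import jax.numpy as jnp


def static_field(**kwargs):
    """Hypothetical helper: a dataclass field flagged as static via metadata."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


def register_with_static_fields(cls):
    """Hypothetical helper: static fields go into aux data, the rest are leaves."""
    fields = dataclasses.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get("static", False)]
    static = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in dynamic)
        aux = tuple(getattr(obj, name) for name in static)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(dynamic, children))
        kwargs.update(zip(static, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_with_static_fields
@dataclasses.dataclass
class Dense:
    w: jnp.ndarray
    use_relu: bool = static_field(default=True)


@jax.jit
def apply(layer, x):
    y = layer.w * x
    # use_relu is part of the treedef, so this is an ordinary Python branch.
    return jnp.maximum(y, 0.0) if layer.use_relu else y


x = jnp.array([2.0, 2.0])
with_relu = apply(Dense(w=jnp.array([1.0, -1.0])), x)
without_relu = apply(Dense(w=jnp.array([1.0, -1.0]), use_relu=False), x)
```

Because the flag travels with the object, `apply` needs no `static_argnums` entry; changing `use_relu` simply produces a different treedef and triggers a retrace.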
A non-registered dataclass can also really help. One thing that comes to mind is configs. They often contain bools, integers, and floats, so they could be raised into JAX as traced values, but they are intended to be constants and to allow for optimisations like removing dead branches and removing multiplications by 0 (for example weight_decay=0).
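A small sketch of the config use case, assuming a hypothetical `Config` dataclass and `sgd_step` function. The dataclass is deliberately not registered as a pytree; it is frozen so that it is hashable and can be passed to `jax.jit` through `static_argnums`.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)  # frozen (with the default eq) makes it hashable
class Config:
    weight_decay: float = 0.0
    learning_rate: float = 0.1


@functools.partial(jax.jit, static_argnums=0)
def sgd_step(config, w, grad):
    # config is a compile-time constant, so this branch is resolved while tracing
    # and the weight decay term is never emitted when weight_decay == 0.
    if config.weight_decay != 0:
        grad = grad + config.weight_decay * w
    return w - config.learning_rate * grad


w = jnp.array([1.0, 2.0])
g = jnp.array([0.5, 0.5])
plain = sgd_step(Config(), w, g)
decayed = sgd_step(Config(weight_decay=0.1), w, g)
```

If `Config` were registered as a pytree instead, `weight_decay` would become a traced value and the Python `if` on it would fail under `jit`, so the branch could no longer be removed at trace time.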