Automatically treat dataclasses as pytrees

JAX should automatically treat dataclasses as pytrees, so they don’t have to be explicitly registered.

Ideally we would also support some syntax for non-differentiable parameters. Flax does so by adding custom metadata into dataclasses.Field.metadata with the special flax.struct.field constructor, which seems like a very clean way to do this.
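A minimal sketch of the metadata-based approach, under stated assumptions: `field(static=...)` and `pytree_dataclass` are hypothetical names, the `"static"` metadata key is an assumption, and this is not Flax's actual implementation. Fields marked static are stored in the treedef's auxiliary data instead of the leaves, so they are neither traced nor differentiated.

```python
import dataclasses

import jax
import jax.numpy as jnp


def field(*, static=False, **kwargs):
    """Hypothetical helper: a dataclasses.field that records a "static" flag."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = static
    return dataclasses.field(metadata=metadata, **kwargs)


def pytree_dataclass(cls):
    """Hypothetical decorator: a frozen dataclass registered as a JAX pytree."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    data_names = [f.name for f in fields if not f.metadata.get("static", False)]
    static_names = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        # Data fields become children; static fields become auxiliary data,
        # which is part of the treedef rather than the leaves.
        children = tuple(getattr(obj, n) for n in data_names)
        aux = tuple(getattr(obj, n) for n in static_names)
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(data_names, children))
        kwargs.update(zip(static_names, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class Linear:
    weight: jax.Array
    bias: jax.Array
    activation: str = field(static=True, default="relu")


def apply(layer, x):
    y = x @ layer.weight + layer.bias
    # `activation` is static, so this Python branch is resolved at trace time.
    return jax.nn.relu(y) if layer.activation == "relu" else y


layer = Linear(weight=jnp.ones((3, 2)), bias=jnp.zeros(2))
x = jnp.ones((4, 3))
grads = jax.grad(lambda l: jnp.sum(apply(l, x)))(layer)  # a Linear of gradients
out = jax.jit(apply)(layer, x)  # no static_argnums needed for `activation`
```

Because the static field lives in the treedef, changing `activation` changes the tree structure, and `jax.jit` retraces rather than treating it as an array.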

I started working on this in a branch, but haven’t tested anything, so it is very likely entirely broken/non-functional! If somebody wants to finish this up it would be awesome 😃 https://github.com/google/jax/compare/master...shoyer:dataclasses-pytree

xref https://github.com/google/jax/issues/1808

Issue Analytics

  • State: open
  • Created 4 years ago
  • Reactions: 15
  • Comments: 21 (14 by maintainers)

Top GitHub Comments

3 reactions
awav commented, Jun 6, 2020

Hello all!

@tomhennigan, @NeilGirdhar thanks for your input.

> For context, in the various other tree libraries (tf.nest and dm-tree) we have pushed back on dataclasses automatically being treated as nests, because for these “struct” types it is not clear that treating them as containers is intended in all cases.

Can the same reasoning be applied to lists, tuples and namedtuples? Often a user does not want to differentiate through these structures either.
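For reference, JAX already flattens lists, tuples and namedtuples as containers, while an instance of a plain, unregistered dataclass is treated as a single opaque leaf. A quick check (`Point` and `PointDC` are illustrative names):

```python
import collections
import dataclasses

import jax

Point = collections.namedtuple("Point", ["x", "y"])


@dataclasses.dataclass
class PointDC:
    x: float
    y: float


# Built-in containers and namedtuples are flattened into their elements.
print(jax.tree_util.tree_leaves([1.0, (2.0, 3.0)]))  # [1.0, 2.0, 3.0]
print(jax.tree_util.tree_leaves(Point(1.0, 2.0)))    # [1.0, 2.0]

# An unregistered dataclass instance is one leaf: JAX does not look inside it.
p = PointDC(1.0, 2.0)
print(jax.tree_util.tree_leaves(p))                  # [PointDC(x=1.0, y=2.0)]
```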

> Instead of treating all dataclasses as jaxtrees, could we instead create a drop-in replacement for dataclass for users who know they want this behavior?

Agreed, and I think the same option should be available for other structures. Actually, this is another important missing bit in JAX! I raised that issue before: https://github.com/google/jax/issues/2588.

Example: a non-parametric Gaussian process model with trainable hyperparameters, the lengthscale and variance of the squared exponential kernel. Often there is a need to experiment with the model by computing gradients w.r.t. only the variance, only the lengthscale, or both. I had to write different code for each case, which is super annoying considering that other frameworks (TF and PyTorch) support trainability of tensors out of the box.
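One way to avoid separate code per case is to keep the hyperparameters in a dict and differentiate only with respect to a chosen subset, closing over the rest as constants. A minimal sketch, assuming a squared exponential kernel and a placeholder objective (not a real GP marginal likelihood); `grad_wrt` is a hypothetical helper, not a JAX API:

```python
import jax
import jax.numpy as jnp


def sq_exp_kernel(params, x1, x2):
    # k(x, x') = variance * exp(-(x - x')^2 / (2 * lengthscale^2))
    d2 = (x1[:, None] - x2[None, :]) ** 2
    return params["variance"] * jnp.exp(-0.5 * d2 / params["lengthscale"] ** 2)


def objective(params, x):
    # Placeholder objective, for illustration only.
    return jnp.sum(sq_exp_kernel(params, x, x))


def grad_wrt(fn, params, trainable, *args):
    # Split params by name; only the `trainable` part is differentiated.
    train = {k: v for k, v in params.items() if k in trainable}
    frozen = {k: v for k, v in params.items() if k not in trainable}
    return jax.grad(lambda t: fn({**frozen, **t}, *args))(train)


params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}
x = jnp.linspace(0.0, 1.0, 5)

g_var = grad_wrt(objective, params, {"variance"}, x)
g_ls = grad_wrt(objective, params, {"lengthscale"}, x)
g_both = grad_wrt(objective, params, {"variance", "lengthscale"}, x)
```

Since the split is by name, the same objective covers all three cases; only the `trainable` set changes.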

Long story short, I think two features would bring more users to JAX:

  • Dataclasses support for organising complex differentiable structures.
  • A converter from differentiable objects to non-differentiable ones.
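The converter item in the list above can be approximated by mapping `jax.lax.stop_gradient` over every leaf of a pytree: values are unchanged, but no gradient flows through them. A minimal sketch; `freeze` is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp


def freeze(tree):
    # Same values, but gradients are blocked at every leaf of `tree`.
    return jax.tree_util.tree_map(jax.lax.stop_gradient, tree)


def loss(params):
    # Treat the kernel hyperparameters as constants inside this loss.
    kernel = freeze(params["kernel"])
    return params["scale"] * kernel["variance"] * kernel["lengthscale"]


params = {
    "scale": jnp.array(3.0),
    "kernel": {"variance": jnp.array(2.0), "lengthscale": jnp.array(0.5)},
}
grads = jax.grad(loss)(params)
# grads["scale"] is 2.0 * 0.5 = 1.0; both grads["kernel"] entries are 0.0
```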

What are the next steps? @shoyer, @mattjj

2 reactions
jheek commented, Mar 25, 2021

> I’d still like to know why this is acceptable for namedtuple but not dataclasses

Named tuples are Iterable and are a subclass of tuple. One could argue it’s an implementation detail that the pytree handling of tuple works only on the tuple type and not on its subclasses. Not handling NamedTuples as tuples would create an inconsistency with how most APIs deal with tuples and Iterables, because those handle NamedTuples as ordinary tuples. But I think the real problem with NamedTuple is that inheriting from tuple is more a buggy side effect than a feature. With dataclass we have a better alternative, and most use cases of NamedTuple should disappear, except maybe when there is truly an order in the fields (like a NamedSequential structure, for example). But I guess deregistering could help push adoption of the right pattern.
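The namedtuple behavior described above can be checked directly: a namedtuple flattens to the same leaves as the equivalent plain tuple, but JAX records its type as a separate node kind, so `tree_map` rebuilds the namedtuple rather than a bare tuple (`Pair` is an illustrative name):

```python
import collections

import jax

Pair = collections.namedtuple("Pair", ["a", "b"])

p = Pair(1.0, 2.0)
t = (1.0, 2.0)

# Same leaves as a plain tuple...
same_leaves = jax.tree_util.tree_leaves(p) == jax.tree_util.tree_leaves(t)  # True
# ...but a different tree structure, because the namedtuple type is recorded.
same_structure = jax.tree_util.tree_structure(p) == jax.tree_util.tree_structure(t)  # False

doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)
print(doubled)  # Pair(a=2.0, b=4.0)
```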

I think registering dataclasses by default (with a flag or always) does lead to problems. It might be worth considering allowing dataclasses to be registered post hoc, e.g.:

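```python
import dataclasses
import jax

# Proposed decorators (a suggestion in this thread, not an existing API then):
@jax.tree_util.dataclass
class Foo:
    x: float

# equivalent to
@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Foo:
    x: float
```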

Having a field decorator for defining metadata can be really useful at times. In JAX it often avoids the need to define functions with flag-like arguments in static_argnums.

A non-registered dataclass can also really help. One thing that comes to mind is configs. They often contain bools, integers, and floats, so they could be lifted into JAX values, but they are intended to be constants, which allows optimisations like removing dead branches and removing a multiply by 0 (for example, weight_decay=0).
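A minimal sketch of that config pattern, assuming a hypothetical `Config` class: a frozen dataclass is hashable, so it can be passed through `static_argnums` of `jax.jit`. Its values are then ordinary Python constants during tracing, so branches on them are resolved at trace time and the weight decay term is never traced when `weight_decay` is 0.

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)  # frozen + eq makes it hashable, as static_argnums requires
class Config:
    use_bias: bool = True
    weight_decay: float = 0.0


@functools.partial(jax.jit, static_argnums=0)
def loss(cfg, w, b, x):
    y = x @ w
    if cfg.use_bias:  # plain Python branch, resolved while tracing
        y = y + b
    value = jnp.mean(y ** 2)
    if cfg.weight_decay:  # skipped entirely when weight_decay == 0
        value = value + cfg.weight_decay * jnp.sum(w ** 2)
    return value


w = jnp.ones((3, 2))
b = jnp.ones(2)
x = jnp.ones((4, 3))
plain = loss(Config(), w, b, x)
decayed = loss(Config(use_bias=False, weight_decay=0.1), w, b, x)
```

Each distinct `Config` value triggers a separate compilation, which is the intended trade-off for values that really are constants.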
