Prevent unnecessary recompilation due to closures in optimizers

Now that JAX supports dataclasses as PyTrees, would it be possible to switch to using them instead of namedtuple? The benefits are explained here.

The biggest benefit would be preventing unnecessary recompilation. The current Optax code uses closures, which causes JAX to unnecessarily recompile a jitted function that accepts a GradientTransformation as a static argument. (The closures are distinct objects that hash differently, so changing the parameters of the GradientTransformation forces the jitted function to recompile.)

A dataclass version of Optax would look something like this.

I am happy to submit a pull request if this change is okay.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 36 (16 by maintainers)

Top GitHub Comments

3 reactions
rosshemsley commented, Nov 24, 2021

(Just writing to reassure you that we have not forgotten about this, it’s just a busy week over here 😃 )

1 reaction
rosshemsley commented, Nov 30, 2021

Hey @NeilGirdhar,

We have discussed these proposals extensively within the optax maintainers team, and gathered feedback from a number of users and developers. We believe that the approach outlined below would be the best way forward.

The request here is for JAX to treat instances of optax.GradientTransformation as pytrees. In @NeilGirdhar’s proposal, this means combining optimizer hyperparameters with the init / update functions into a single dataclass. In @cgarciae’s proposal, the state is also included in the dataclass.

Making gradient transformations into pytree classes provides the ability to pass these objects to JAX jitted functions as dynamic rather than static arguments. On balance, the idea is reasonable and might be easier to understand for users who like OO patterns and the specific dataclass-based strand of OO programming that is being used in some JAX codebases. There are also downsides to this approach, though. Notably, the current optax API has a very sharp boundary between “functions implementing user behavior” (the init / update methods) and “data containers” (i.e. the state). In the above proposals, hyperparameters and/or state are “folded into” the optimizer class, and the pure functions become methods on a mutable class instead. This is not a problem per se (it’s a common pattern in OO!), but it is different from the programming patterns of many of our users.
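To make the "dynamic rather than static argument" point concrete, here is a minimal sketch of a transformation written as a pytree dataclass. The class `ScaleByLearningRate` and the function `step` are hypothetical and are not the tjax, Treex, or optax API; the sketch registers the class by hand with `jax.tree_util.register_pytree_node_class` and treats the learning rate as a pytree leaf.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class ScaleByLearningRate:
    """Hypothetical SGD-like transformation whose hyperparameter is a leaf."""

    learning_rate: float

    def init(self, params):
        return ()  # this toy transformation keeps no state

    def update(self, grads, state, params=None):
        updates = jax.tree_util.tree_map(
            lambda g: -self.learning_rate * g, grads)
        return updates, state

    def tree_flatten(self):
        # The learning rate is a child (dynamic); there is no static aux data.
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


trace_log = []  # grows once per (re)trace of `step`


@jax.jit
def step(tx, params, grads, state):
    trace_log.append(1)  # runs only while tracing
    updates, state = tx.update(grads, state, params)
    new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
    return new_params, state


params = jnp.ones(3)
grads = jnp.full(3, 0.5)
for lr in (0.1, 0.01):
    tx = ScaleByLearningRate(lr)  # new object, different hyperparameter
    new_params, _ = step(tx, params, grads, tx.init(params))
```

Because `tx` is flattened into its leaves before the jit cache lookup, the two calls share one trace in this sketch: the learning rate is just another traced input. The cost is the one named above: the hyperparameter now lives inside the object that also carries the `init` / `update` logic.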

So what is the way forward? As Optax developers, we do care a lot about users being able to extend and modify the optax components. The whole design focuses on composability and extensibility. We believe the best way forwards here is thus not to use dataclasses in each individual component, but rather to create wrappers that turn optax optimisers into dataclasses with the desired behavior.

This is the approach taken by @n2cholas in his inject_hyperparams (already upstreamed into optax) and is the approach taken by @cgarciae (in Treex). Even when living in libraries outside of optax, these wrappers are lighter to maintain than the current approach in tjax: for instance, in @cgarciae’s Treex only one wrapper is maintained, and all optax transformations are exposed after wrapping them programmatically (instead of defining new separate classes for each optax component as in tjax, where new code has to be written every time components are added to optax).

Our suggestion is thus to proceed in two stages. In an initial stage, tjax could define its own wrapper (similarly to Treex). Once a common dataclass implementation is provided by JAX itself, we would be happy to see extensions of this kind upstreamed into optax (as we were to see inject_hyperparams).

E.g. we could imagine exposing a triplet of extensions:

  • inject_hyperparams (@n2cholas’s wrapper)
  • wrap_transformation_as_pytree (implementing @NeilGirdhar’s extension)
  • wrap_transformation_as_pytree_with_state (implementing @cgarciae’s extension)

We hope the above is acceptable, and while this might not be everybody’s favorite solution, we want to stress that we really appreciate everybody’s feedback and contributions, and we hope to continue seeing you all engaged in the development of optax and contributing to the library.
