It's very difficult to write libraries that support both Haiku and plain JAX
I would like to extend my fixed point solver to work in both Haiku and JAX. I thought it would be as simple as replacing `jax.lax.scan` with `hk.scan`, etc. Unfortunately, I get:

ValueError: hk.jit() should not be used outside of hk.transform. Use jax.jit() instead.

Perhaps I'm missing something, but would it be possible to reverse the design decision in https://github.com/deepmind/dm-haiku/pull/17 to raise an error here, and instead simply fall back to the JAX version of the function when the stateful context isn't needed?

Also, just out of curiosity: is `jacfwd` broken in Haiku, or does it not need a stateful wrapper?
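One workaround for library code is to let the caller choose the scan implementation. This is a sketch under our own assumptions, not something taken from Haiku or from this thread. The hypothetical `fixed_point_iterate` below runs a fixed number of iterations of `x <- f(x)` using `jax.lax.scan` by default. A Haiku user could pass `scan_fn=hk.scan` from inside an `hk.transform`-ed function, assuming `hk.scan` accepts the same `(f, init, xs, length=...)` arguments as `jax.lax.scan`.

```python
import jax


def fixed_point_iterate(f, x0, num_iters=50, scan_fn=jax.lax.scan):
    """Hypothetical helper: apply x <- f(x) exactly `num_iters` times.

    `scan_fn` is injected so the library never imports Haiku itself:
    plain JAX users get jax.lax.scan, while Haiku users can pass hk.scan
    (assumed here to share jax.lax.scan's calling convention).
    """

    def step(x, _):
        # Carry the current iterate; emit no per-step output.
        return f(x), None

    # xs=None plus an explicit length makes scan loop num_iters times.
    x_final, _ = scan_fn(step, x0, None, length=num_iters)
    return x_final
```

For example, `fixed_point_iterate(jnp.cos, jnp.array(1.0))` converges to roughly 0.739, the fixed point of cos. Keeping Haiku out of the library's imports means plain JAX users never need it installed.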
@trevorcai WDYT?
We decided not to merge https://github.com/google/jax/pull/4117.

But let me share how we've solved this problem in our own codebase, using our own versions of higher-order functions like `scan`. First, we define a version of `scan` that does the right thing for initialization:

Then in Haiku, you can write something like:
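A minimal sketch of one way an initialization-aware `scan` could work (hypothetical; not the code from this thread). During initialization, the step function is called once on the first slice of `xs`, outside any JAX transformation, so that any parameters it creates are made eagerly. The real loop then runs with `jax.lax.scan`, assuming the step function only reads those already-created parameters from then on. The `is_initializing` predicate is an assumption of this sketch; inside Haiku one would pass `hk.running_init`.

```python
import jax


def init_aware_scan(f, init, xs, is_initializing):
    """Hypothetical scan wrapper that is safe to call during parameter init.

    If `is_initializing()` is true, `f` is called once eagerly on the first
    slice of `xs`, purely for its side effect of creating parameters outside
    the scan trace. The actual loop is always run with jax.lax.scan, so the
    returned (carry, ys) is the same whether or not we are initializing.
    """
    if is_initializing():
        # Take element 0 of every leaf in xs (xs may be an array or a pytree).
        first_x = jax.tree_util.tree_map(lambda x: x[0], xs)
        f(init, first_x)  # result discarded; only parameter creation matters
    return jax.lax.scan(f, init, xs)
```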
Presumably this sort of thing could be done for most or all higher-order functions in JAX. It doesn't even have to be library-specific, so I can imagine this being a good fit for a third-party library or perhaps even a `jax.experimental` module (but not built into JAX core).

Yes, I'm happy to revive google/jax#4117 😃
My original motivation was exactly this issue: we have code that we want to support both Haiku and JAX. So far we've worked around this by writing our own stateful versions of higher-order functions like `jit` that switch back and forth, but that isn't very extensible.