It's very difficult to write libraries that support both Haiku and plain Jax


I would like to extend my fixed point solver to work in both Haiku and Jax. I thought it would be as simple as replacing jax.lax.scan with hk.scan, etc. Unfortunately, I get

 ValueError: hk.jit() should not be used outside of hk.transform. Use jax.jit() instead.

Perhaps I’m missing something, but would it be possible to reverse the design decision (https://github.com/deepmind/dm-haiku/pull/17) that makes these functions raise an error outside of hk.transform, and instead simply fall back to the JAX version of the function when the stateful context isn’t needed?

Also, just out of curiosity: is jacfwd broken in Haiku, or does it not need a stateful wrapper?

@trevorcai WDYT?
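One way to picture the request is a solver that takes its scan implementation as an argument, so the same code can be handed jax.lax.scan in plain JAX or hk.scan inside hk.transform. The sketch below is only an illustration: the issue does not show the actual solver, so `fixed_point`, its `scan` parameter, and the cosine example are all hypothetical. It also assumes the supplied scan accepts `(f, init, xs, length)` positionally, as jax.lax.scan does.

```python
import jax
import jax.numpy as jnp


def fixed_point(f, x0, num_steps, scan=jax.lax.scan):
    """Iterate x <- f(x) for num_steps steps using the supplied scan.

    `scan` is a hypothetical injection point: plain JAX callers keep the
    default, while code running inside hk.transform could pass hk.scan.
    """
    def body(x, _):
        # Carry the iterate forward; no per-step output is collected.
        return f(x), None

    x_final, _ = scan(body, x0, None, num_steps)
    return x_final


# Plain JAX usage: iterate cos(x) starting from 1.0.
x_star = fixed_point(jnp.cos, jnp.asarray(1.0), num_steps=50)
print(x_star)
```

Passing the scan function explicitly avoids the ValueError quoted above, but it pushes the choice onto every caller, which is the inconvenience the issue describes.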

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 14 (7 by maintainers)

Top GitHub Comments

2 reactions
shoyer commented, Feb 26, 2021

We decided to not merge https://github.com/google/jax/pull/4117.

But let me share how we’ve solved this problem in our own codebase, using our own versions of higher order functions like scan.

First, we define a version of scan that does the right thing for initialization:

import jax
import jax.numpy as jnp
import contextlib

_INITIALIZING = False

@contextlib.contextmanager
def init_context():
  global _INITIALIZING
  assert not _INITIALIZING
  _INITIALIZING = True
  try:
    yield
  finally:
    # Reset the flag even if initialization raises.
    _INITIALIZING = False

def init_safe_scan(f, init, xs, length=None, default_scan=jax.lax.scan):
  # Version of lax.scan that is safe to use under Flax/Haiku initialization.
  if _INITIALIZING:  # could also use hk.running_init() here
    # Run the body once on the first slice so parameters are created only once.
    # jax.tree_flatten / jax.tree_multimap were removed; use jax.tree_util.
    xs_flat, treedef = jax.tree_util.tree_flatten(xs)
    if length is None:
      length, = {x.shape[0] for x in xs_flat}
    x0 = jax.tree_util.tree_unflatten(treedef, [x[0] for x in xs_flat])
    carry, y0 = f(init, x0)
    # Tile the single output so ys has the shapes a real scan would produce.
    ys = jax.tree_util.tree_map(lambda *z: jnp.stack(z), *(length * [y0]))
    return carry, ys
  return default_scan(f, init, xs, length)

Then in Haiku, you can write something like:

import haiku as hk

def neural_net(x):
  return hk.Linear(5)(x)

def haiku_init_safe_scan(f, init, xs, length=None):
  # Forward the caller's length instead of always passing None.
  return init_safe_scan(f, init, xs, length=length, default_scan=hk.scan)

def my_model(step_fn):
  def doubled_step(x):
    y, _ = haiku_init_safe_scan(
      lambda x, _: (step_fn(x), _), init=x, xs=jnp.arange(2))
    return y
  return doubled_step

rng = jax.random.PRNGKey(42)
x = jnp.ones([1, 5])
forward = hk.transform(my_model(neural_net))
with init_context():
  params = forward.init(rng, x)
print(params)  # only a single set of weights
logits = forward.apply(params, rng, x)
print(logits)  # does not crash

Presumably this sort of thing could be done for most or all higher-order functions in JAX. It doesn’t even have to be library-specific, so I can imagine this being a good fit for a third-party library or perhaps even a jax.experimental module (but not built into JAX core).

1 reaction
shoyer commented, Feb 7, 2021

Yes, I’m happy to revive google/jax#4117 😃

My original motivation was actually exactly this issue: we have code that we want to support both Haiku and JAX. So far we’ve gotten around this by writing our own stateful version of higher order functions like jit for switching back and forth, but that isn’t very extensible.
