
hk.stateful.remat generates excess un-pruneable HLO


jax.remat wraps all of its inputs with _foil_cse.

When hk.stateful.remat threads Haiku state into and out of the rematerialized function, the state that comes back out is the output of _foil_cse. Any downstream use of that state therefore reads the _foil_cse’d param/state rather than the original.
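Each rng / compare / select triple in the HLO dump below is one application of this wrapping. The following is a minimal sketch of that pattern in plain JAX. It assumes only what the dump shows: a uniform draw from [0, 1), a "< 2" test that is therefore always true, and a select between the input and zero. `foil_cse_sketch` is a hypothetical name and is not JAX's internal `_foil_cse`. Also, JAX documents `lax.rng_uniform` as experimental, so treat this as an illustration only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def foil_cse_sketch(x):
  # Hypothetical re-creation of one rng/compare/select triple from the HLO dump.
  # lax.rng_uniform lowers to XLA's `rng` op (JAX marks it experimental).
  u = lax.rng_uniform(jnp.float32(0), jnp.float32(1), ())
  # u lies in [0, 1), so this predicate is always true...
  keep = u < jnp.float32(2)
  # ...but it depends on a random draw, so the select cannot be folded away.
  return lax.select(keep, x, jnp.zeros_like(x))

y = jax.jit(foil_cse_sketch)(jnp.float32(1.5))
```

The result is always `x`. Because each copy draws its own random value, however, XLA cannot treat two copies as the same expression.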

Example:

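import jax
import jax.numpy as jnp

def f(x, ctxt):
  # ctxt is passed through unchanged, but it is still an output of the remat'd call.
  return jnp.sin(x + ctxt[0]), ctxt

@jax.jit
def g(x):
  ctxt = [x + i for i in range(2)]
  x, ctxt = jax.remat(f)(x, ctxt)
  # ctxt[1] here comes out of the remat'd call (the _foil_cse'd value), not the list above.
  return jnp.sin(x + ctxt[1])

g(1.).block_until_ready()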

This results in HLO that looks like:

HloModule jit_g__1.46, is_scheduled=true

ENTRY jit_g__1.46 {
  constant.2 = f32[]{:T(256)} constant(2)
  constant = f32[]{:T(256)} constant(0)
  constant.5 = f32[]{:T(256)} constant(1)
  rng.2 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare.2 = pred[]{:T(256)E(32)} compare(rng.2, constant.2), direction=LT
  rng.1 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare.1 = pred[]{:T(256)E(32)} compare(rng.1, constant.2), direction=LT
  rng = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
  compare = pred[]{:T(256)E(32)} compare(rng, constant.2), direction=LT
  parameter.1 = f32[]{:T(256)} parameter(0), parameter_replication={false}
  select = f32[]{:T(256)} select(compare, parameter.1, constant)
  select.1 = f32[]{:T(256)} select(compare.1, parameter.1, constant)
  add = f32[]{:T(256)} add(select, select.1)
  sine = f32[]{:T(256)} sine(add)
  add.6 = f32[]{:T(256)} add(parameter.1, constant.5)
  select.2 = f32[]{:T(256)} select(compare.2, add.6, constant)
  add.43 = f32[]{:T(256)} add(sine, select.2)
  sine.44 = f32[]{:T(256)} sine(add.43)
  ROOT tuple.45 = (f32[]{:T(256)}) tuple(sine.44)
}
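To get the equivalent dump for your own JAX version and backend, you can lower and compile the example ahead of time. This is a sketch that uses JAX's ahead-of-time API (`jit(...).lower(...).compile().as_text()`). The dump above carries TPU layout annotations (`{:T(256)}`), so the text will look different on CPU or GPU. Newer JAX releases may also lower remat differently, and the rng count printed below is only a rough heuristic.

```python
import jax
import jax.numpy as jnp

def f(x, ctxt):
  return jnp.sin(x + ctxt[0]), ctxt

def g(x):
  ctxt = [x + i for i in range(2)]
  x, ctxt = jax.remat(f)(x, ctxt)
  return jnp.sin(x + ctxt[1])

# Lower for a float32 scalar input, compile, and print the optimized HLO.
compiled = jax.jit(g).lower(jnp.float32(1.0)).compile()
hlo_text = compiled.as_text()
print(hlo_text)

# Rough check for leftover _foil_cse residue: count `rng` instructions.
num_rng = hlo_text.count(" rng(")
print("rng ops:", num_rng)
```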

Possible solutions:

  1. Reduce how much state is threaded into and out of stateful_fun, especially during apply.
    • Params are immutable during apply, don’t thread them in/out.
    • Pre-split the RNG and populate a new hk.PRNGSequence inside stateful_fun so that RNG doesn’t get threaded in/out.
    • Only update state in place for entries that have actually changed. JAX’s referential transparency makes this challenging when the Haiku function as a whole is not jitted but inner functions are jitted via hk.jit.
  2. Rebuild hk.remat on top of hk._src.lift.
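The "pre-split the RNG" idea can be sketched in plain JAX. The caller splits the key before the rematerialized call and passes the subkeys in, so no key has to be threaded back out. The proposal itself is to populate a fresh hk.PRNGSequence from such keys inside stateful_fun. The sketch does not model Haiku at all, and `block`, `forward`, and `NUM_KEYS` are hypothetical names.

```python
import jax
import jax.numpy as jnp

NUM_KEYS = 2  # assumption: the caller knows how many draws the block needs

def block(x, keys):
  # Consumes pre-split keys and returns no key, so nothing RNG-related
  # is threaded back out of the rematerialized function.
  noise0 = jax.random.normal(keys[0], x.shape)
  noise1 = jax.random.normal(keys[1], x.shape)
  return jnp.tanh(x + noise0) * noise1

@jax.jit
def forward(x, key):
  key, block_key = jax.random.split(key)
  keys = jax.random.split(block_key, NUM_KEYS)  # pre-split before the remat
  y = jax.checkpoint(block)(x, keys)
  return y, key  # the outer key sequence continues outside the remat
```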

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

trevorcai commented, Mar 4, 2020 (1 reaction)

I was thinking the same thing; we can view the returned value as an overlay over the original values. There are a couple of annoying things here to be aware of:

  • For state, updated entries will still overwrite the original state member of StatePair with the _foil_cse’d value. This is pretty annoying, though not the worst thing ever.
  • The _foil_cse HLO does not get pruned when it is unused, so we’re left with a huge number of vestigial constant and rng HLO ops. Blake from the XLA team speculates that this is because the ops could be side-effecting.

As a side note, it’s not clear to me whether _foil_cse-ing params (which don’t need to be _foil_cse’d, as they’re constant!) is even a good idea. This is sufficiently murky that I think Haiku shouldn’t take a stance on this.
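The "overlay" view can be sketched in plain JAX. The rematerialized function returns only the entries it changed, and the caller merges them over the original dict. Untouched entries (here `offset`) are then read from the caller's original values rather than from remat's outputs. `step`, `apply_overlay`, and the dict layout are hypothetical, and Haiku's actual state structures (such as StatePair) are not modeled.

```python
import jax
import jax.numpy as jnp

def apply_overlay(original, updates):
  # Hypothetical helper: keys present in `updates` win; every other entry keeps
  # the caller's original value rather than a value routed through remat.
  merged = dict(original)
  merged.update(updates)
  return merged

def step(x, state):
  # Return only the entry that actually changes; "offset" is read but not returned.
  new_count = state["count"] + 1
  return jnp.sin(x + state["offset"]), {"count": new_count}

@jax.jit
def g(x, state):
  y, updates = jax.checkpoint(step)(x, state)
  state = apply_overlay(state, updates)
  # state["offset"] is the original input here, not an output of the remat'd call.
  return jnp.sin(y + state["offset"]), state
```

Only `count` passes back out through the checkpointed call. This is the same idea as the first possible solution above, applied to a toy dict instead of Haiku's state.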

trevorcai commented, May 1, 2020 (0 reactions)

As a result of google/jax#2391, it looks like the problems get optimized away; we’re not even incurring unnecessary serialization right now.

These changes may still be good to make defensively.
