hk.stateful.remat generates excess un-pruneable HLO
jax.remat wraps all of its inputs with _foil_cse.
When we do the state-threading in hk.stateful.remat, the threaded-out state is now the output of _foil_cse. Any downstream use of this state now accesses the foil-cse’d param/state, rather than the original.
Example:
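import jax
import jax.numpy as jnp

def f(x, ctxt):
    # ctxt is returned unchanged, mimicking state threaded out of remat.
    return jnp.sin(x + ctxt[0]), ctxt

@jax.jit
def g(x):
    ctxt = [x + i for i in range(2)]
    x, ctxt = jax.remat(f)(x, ctxt)
    # ctxt[1] is now the value returned by remat, not the original.
    return jnp.sin(x + ctxt[1])

g(1.).block_until_ready()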
This results in HLO that looks like:
HloModule jit_g__1.46, is_scheduled=true
ENTRY jit_g__1.46 {
constant.2 = f32[]{:T(256)} constant(2)
constant = f32[]{:T(256)} constant(0)
constant.5 = f32[]{:T(256)} constant(1)
rng.2 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
compare.2 = pred[]{:T(256)E(32)} compare(rng.2, constant.2), direction=LT
rng.1 = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
compare.1 = pred[]{:T(256)E(32)} compare(rng.1, constant.2), direction=LT
rng = f32[]{:T(256)} rng(constant, constant.5), distribution=rng_uniform
compare = pred[]{:T(256)E(32)} compare(rng, constant.2), direction=LT
parameter.1 = f32[]{:T(256)} parameter(0), parameter_replication={false}
select = f32[]{:T(256)} select(compare, parameter.1, constant)
select.1 = f32[]{:T(256)} select(compare.1, parameter.1, constant)
add = f32[]{:T(256)} add(select, select.1)
sine = f32[]{:T(256)} sine(add)
add.6 = f32[]{:T(256)} add(parameter.1, constant.5)
select.2 = f32[]{:T(256)} select(compare.2, add.6, constant)
add.43 = f32[]{:T(256)} add(sine, select.2)
sine.44 = f32[]{:T(256)} sine(add.43)
ROOT tuple.45 = (f32[]{:T(256)}) tuple(sine.44)
}
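A dump like the one above can be reproduced roughly as follows. This is a minimal sketch that assumes a recent JAX release with the ahead-of-time lowering API (lower, compile, as_text); the exact text depends on the JAX version and backend, so it will not necessarily match the listing above.

```python
import jax
import jax.numpy as jnp

def f(x, ctxt):
    return jnp.sin(x + ctxt[0]), ctxt

@jax.jit
def g(x):
    ctxt = [x + i for i in range(2)]
    x, ctxt = jax.remat(f)(x, ctxt)
    return jnp.sin(x + ctxt[1])

# The jaxpr shows where remat sits in the traced program.
jaxpr_text = str(jax.make_jaxpr(g)(1.))

# The compiled module text shows what is left after XLA's optimizations.
hlo_text = g.lower(1.).compile().as_text()
print(hlo_text)
```

Comparing this output with and without the jax.remat call is one way to check whether the extra ops survive optimization on a given setup.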
Possible solutions:
- Reduce the amount of state-motion in & out of stateful_fun, especially during apply.
  - Params are immutable during apply, don’t thread them in/out.
  - Pre-split the RNG and populate a new hk.PRNGSequence inside stateful_fun so that RNG doesn’t get threaded in/out.
  - state is only updated in-place for state that’s actually been changed. JAX referential transparency makes this challenging for the case in which Haiku is not jitted but internal functions are via hk.jit.
- Rebuild hk.remat on top of hk._src.lift.
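The first two sub-points can be sketched in plain JAX, without Haiku. This is an illustration under stated assumptions, not how hk.remat is implemented: layer, params, and the key names are hypothetical stand-ins. Params and the RNG key go into the rematerialized function but are never returned from it, so downstream code keeps reading the original values.

```python
import jax
import jax.numpy as jnp

def layer(x, params, key):
    # Hypothetical stand-in for a rematerialized module body.
    noise = jax.random.normal(key, x.shape)
    # Only the activation is returned; params and key are not threaded out.
    return jnp.sin(x * params["w"] + noise)

@jax.jit
def g(x, params, key):
    # Pre-split: one key is consumed inside remat, the other stays outside.
    inner_key, outer_key = jax.random.split(key)
    y = jax.remat(layer)(x, params, inner_key)
    # Downstream code uses the original params and the outer key,
    # not values that came back out of remat.
    return y * params["w"] + jax.random.normal(outer_key, y.shape)

params = {"w": jnp.float32(2.0)}
x = jnp.ones((3,))
key = jax.random.PRNGKey(0)
out = g(x, params, key)
```

The trade-off is that the caller has to decide up front how many keys the inner function needs, since it can no longer hand back an advanced key.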
Issue Analytics
- State:
- Created 4 years ago
- Comments:5 (3 by maintainers)
I was thinking the same thing; we can view the returned value as an overlay over the original values. There are a couple of annoying things here to be aware of:
- [...] StatePair to be the _foil_cse’d one. This is pretty annoying, though not the worst thing ever.
- [...] constant and rng HLO ops. Blake from the XLA team speculates that this is because the ops could be side-effecting.

As a side note, it’s not clear to me whether _foil_cse-ing params (which don’t need to be _foil_cse’d, as they’re constant!) is even a good idea. This is sufficiently murky that I think Haiku shouldn’t take a stance on this.
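The overlay idea from the start of this comment can be sketched in plain JAX. This is a hedged illustration only: overlay is a hypothetical helper, and the dict-based state is an assumption that stands in for Haiku's internal state structures. The rematerialized function returns only the entries it changed, and those are laid over the original state, so untouched entries never pass through remat's outputs.

```python
import jax
import jax.numpy as jnp

def overlay(original, updates):
    # Hypothetical helper: changed entries win, the rest stay original.
    merged = dict(original)
    merged.update(updates)
    return merged

def f(x, state):
    y = jnp.sin(x + state["a"])
    # Return only the entry that actually changed.
    return y, {"a": state["a"] + 1.0}

@jax.jit
def g(x):
    state = {"a": x, "b": x + 1.0}
    y, changed = jax.remat(f)(x, state)
    state = overlay(state, changed)
    # state["b"] is still the original value, not an output of remat.
    return jnp.sin(y + state["b"]), state["a"]

out, new_a = g(1.)
```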
As a result of google/jax#2391, it looks like the problems get optimized away; we’re not even incurring unnecessary serialization right now.
These changes may still be good to make defensively.