index-dependent scan function `lax.scani`
I am interested in training recurrent networks for which the transition dynamics have some sort of time-dependence. For example, the network might evolve linearly from time t1=0 to time t2 and be clamped at some constant parameter array u from then on. In normal Python code I might write something like this:
for step in range(n_steps):
    x = a.dot(x) if step < t2 else u
I would like to differentiate through these dynamics using reverse-mode, so I’ve been trying to use lax.scan. However, I’m not sure how to introduce time-dependence into the scanning function f. Right now, I’ve defined two transition functions f1 and f2, one for each of the two cases:
carry, _ = lax.scan(f1, x0, length=t2)
carry, _ = lax.scan(f2, carry, length=n_steps - t2)
This would get quite annoying when my transition dynamics are much more complicated.
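For concreteness, a minimal runnable version of this two-scan workaround might look like the following. The values of a, u, x0, n_steps and t2 are made up for illustration and are not from the original post; f1 and f2 are one plausible way to write the two transition functions.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative values only; none of these come from the original post.
a = jnp.array([[0.9, 0.1], [0.0, 0.8]])
u = jnp.array([1.0, -1.0])
x0 = jnp.array([1.0, 2.0])
n_steps, t2 = 5, 3

def f1(carry, _):
    # Linear phase: x <- a @ x. No per-step input or output is used.
    return a.dot(carry), None

def f2(carry, _):
    # Clamped phase: x <- u, regardless of the current state.
    return u, None

# One scan per phase; xs=None means "no scanned inputs", so length is required.
carry, _ = lax.scan(f1, x0, None, length=t2)
carry, _ = lax.scan(f2, carry, None, length=n_steps - t2)
```

Every additional regime in the dynamics needs another transition function and another lax.scan call, which is the bookkeeping the question is trying to avoid.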
Therefore, I was wondering if it would be possible to have a function lax.scani which takes a scanning function f with type signature f : int -> c -> a -> (c, b), where the first argument of f is the index of the element it is scanning; and importantly, we can use this integer index to do control flow. In the example above, we would have
def f(t, carry, x):
    new_carry = a.dot(carry) if t < t2 else u
    return new_carry, None  # (c, b), matching the signature above
carry, _ = lax.scani(f, x0, length=n_steps)
Issue Analytics
- Created 4 years ago
- Comments:6 (3 by maintainers)
@tachukao yes, using lax.cond the control flow you write can always be staged out (e.g. by jit, or by use in a scan body) and also differentiated. It’s awkward, but it’s the only robust way we’ve found to embed structured control flow in Python.

You can always avoid all this structured control flow stuff (lax.scan, lax.cond, etc) and write things with regular Python for-loops and ifs. JAX can differentiate native Python! But if you use jit on a Python loop, compile times may get long (because the loop is essentially unrolled into the XLA computation). (The purpose of lax.scan is to stage out a loop construct to XLA (without unrolling) and thus give good compile times.)

Here’s sketch code for how you might write it so that the loop and other control flow stays in Python, but you can still use jit on some parts:

You only need to write things in terms of lax.scan/lax.cond if you need more performance because you want to jit the whole rnn function.

If we introduced a lax.scani kind of function, it’d just be a wrapper around lax.scan and lax.cond, but our policy is to avoid wrappers unless they’re very commonly needed.

I think we covered the original question, so I’m going to close this issue (otherwise we’ll drown in issues!), but please open a new one if you have new questions!
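The sketch code mentioned in the reply above is not preserved in this copy of the thread. As a stand-in, and not the maintainer's original code, here is a minimal sketch of the idea it describes: the loop and the `if` stay in plain Python, and only the per-step computation is wrapped in jit. The names linear_step, rnn and loss, and all numeric values, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative values only; none of these come from the thread.
a = jnp.array([[0.9, 0.1], [0.0, 0.8]])
u = jnp.array([1.0, -1.0])
x0 = jnp.array([1.0, 2.0])

@jax.jit
def linear_step(a, x):
    # Only this small piece is compiled; the loop itself stays in Python.
    return a.dot(x)

def rnn(a, u, x0, n_steps, t2):
    x, trajectory = x0, []
    for step in range(n_steps):
        # `step` is an ordinary Python int, so a normal `if` expression works.
        x = linear_step(a, x) if step < t2 else u
        trajectory.append(x)
    return jnp.stack(trajectory)  # shape (n_steps, state_dim)

def loss(a, u, x0):
    # Hypothetical scalar objective over the whole trajectory.
    return jnp.sum(rnn(a, u, x0, n_steps=5, t2=3) ** 2)

# Reverse-mode gradients with respect to both a and u.
grad_a, grad_u = jax.grad(loss, argnums=(0, 1))(a, u, x0)
```

Because rnn itself is not jitted, the Python loop runs eagerly and is re-traced for differentiation on each call to jax.grad; that is the trade-off the reply describes against staging the whole loop out with lax.scan.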
Hi Neil, thanks for the suggestion - I certainly can. I guess the problem I have now is just that I need to figure out how to use lax.cond to do control flow on the time index i in a way that is differentiable, as @mattjj suggested above. This I haven’t really explored.
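One way this could be done, sketched under assumptions: scan over jnp.arange(length) so the step index arrives as a scanned input, and branch on it with lax.cond. The scani helper below is hypothetical (it is not part of JAX), make_step and loss are illustrative names, the numeric values are made up, and the lax.cond call assumes a recent JAX release where the signature is lax.cond(pred, true_fun, false_fun, *operands).

```python
import jax
import jax.numpy as jnp
from jax import lax

def scani(f, init, xs=None, *, length):
    """Hypothetical helper, not part of JAX: lax.scan whose body also gets the step index."""
    def body(carry, i_and_x):
        i, x = i_and_x
        return f(i, carry, x)
    # The index is just another scanned input; xs=None contributes no leaves.
    return lax.scan(body, init, (jnp.arange(length), xs), length=length)

def make_step(a, u, t2):
    def f(t, carry, x):
        # t is a traced integer here, so a Python `if t < t2` would fail;
        # lax.cond selects the branch at run time instead.
        new_carry = lax.cond(
            t < t2,
            lambda c: a.dot(c),  # linear phase
            lambda c: u,         # clamped phase
            carry,
        )
        return new_carry, new_carry  # (carry, per-step output)
    return f

def loss(a, u, x0):
    # Hypothetical scalar objective over the whole trajectory.
    _, trajectory = scani(make_step(a, u, t2=3), x0, length=5)
    return jnp.sum(trajectory ** 2)

# Illustrative values only; none of these come from the thread.
a = jnp.array([[0.9, 0.1], [0.0, 0.8]])
u = jnp.array([1.0, -1.0])
x0 = jnp.array([1.0, 2.0])

# The whole loop is staged out, so it can be jitted and differentiated together.
grad_a, grad_u = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, u, x0)
```

Both branches of lax.cond must return arrays of the same shape and dtype as each other, which holds here because a.dot(c) and u are both length-2 float arrays.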