index-dependent scan function `lax.scani`


I am interested in training recurrent networks for which the transition dynamics have some sort of time-dependence. For example, the network might evolve linearly from time t1=0 to time t2 and be clamped at some constant parameter array u from then on. In normal Python code I might write something like this:

for step in range(n_steps):
  x = a.dot(x) if step < t2 else u

I would like to differentiate through these dynamics using reverse-mode, so I’ve been trying to use lax.scan. However, I’m not sure how to introduce time-dependence into the scanning function f. Right now, I’ve defined two transition functions, f1 and f2, one for each of the two cases:

carry, _ = lax.scan(f1, x0, length=t2)
carry, _ = lax.scan(f2, carry, length=n_steps - t2)

This would get quite annoying when my transition dynamics are much more complicated.

Therefore, I was wondering if it would be possible to have a function lax.scani which takes a scanning function f with type signature f : int -> c -> a -> (c, b) where the first argument of f is the index of the element it is scanning; and importantly, we can use this integer index to do control flow. In the example above, we would have

def f(t, carry, x):
  new_carry = a.dot(carry) if t < t2 else u
  return new_carry, None  # (carry, per-step output), matching the signature above

carry, _ = lax.scani(f, x0, length=n_steps)

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, Mar 10, 2020

@tachukao yes, using lax.cond the control flow you write can always be staged out (i.e. by jit, or use in a scan body) and also differentiated. It’s awkward, but it’s the only robust way we’ve found to embed structured control flow in Python.
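As a minimal sketch of what that could look like for the clamped dynamics in the question: the step index is fed in as the scanned input via jnp.arange, and lax.cond picks the branch. The values of a, u, x0, t2 and n_steps, and the names step and rollout, are made up for illustration and are not from the original issue.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative values only; none of these come from the original issue.
n_steps, t2 = 6, 3
a = 0.5 * jnp.eye(2)
u = jnp.array([1.0, -1.0])
x0 = jnp.array([2.0, 4.0])

def step(carry, t):
    # `t` is a traced integer inside scan, so a Python `if` would fail here.
    new_carry = lax.cond(
        t < t2,
        lambda c: a @ c,  # linear dynamics before t2
        lambda c: u,      # clamped to u from t2 onward
        carry,
    )
    return new_carry, new_carry  # (next carry, per-step output)

@jax.jit
def rollout(x0):
    # The step indices are the scanned-over inputs.
    return lax.scan(step, x0, jnp.arange(n_steps))

final, trajectory = rollout(x0)
```

Both branches have to return values with the same shape and dtype, which is why the clamped branch returns u rather than, say, a Python scalar.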

You can always avoid all this structured control flow stuff (lax.scan, lax.cond, etc) and write things with regular Python for-loops and ifs. JAX can differentiate native Python! But if you use jit on a Python loop, compile times may get long (because the loop is essentially unrolled into the XLA computation). (The purpose of lax.scan is to stage out a loop construct to XLA (without unrolling) and thus give good compile times.)
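To make the "JAX can differentiate native Python" point concrete, here is a small sketch that takes a reverse-mode gradient straight through a Python for-loop and if. The loss (a sum of squared states over all steps) and all numeric values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative values only; none of these come from the original issue.
n_steps, t2 = 4, 2
u = jnp.array([1.0, -1.0])
x0 = jnp.array([2.0, 4.0])

def rollout_loss(a):
    x, total = x0, 0.0
    for step in range(n_steps):        # ordinary Python loop
        x = a @ x if step < t2 else u  # `step` is a Python int, so `if` works
        total = total + jnp.sum(x ** 2)
    return total

a = 0.5 * jnp.eye(2)
loss_value = rollout_loss(a)
grad_a = jax.grad(rollout_loss)(a)  # reverse-mode gradient w.r.t. a
```

This works because step is a concrete Python integer here, not a traced value; the cost is that every iteration becomes part of the traced computation.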

Here’s sketch code for how you might write it so that the loop and other control flow stays in Python, but you can still use jit on some parts:

from functools import partial
from jax import jit

@jit
def f(params, hidden, x):
  ...

@jit 
def g(params, hidden, x):
  ...

...


def rnn(params, hidden, inputs):
  outputs = []  # collect the per-step outputs
  for i, x in enumerate(inputs):
    if i % 10 == 0:
      hidden, y = f(params, hidden, x)
    elif i % 10 == 1:
      hidden, y = g(params, hidden, x)
    elif ...
    outputs.append(y)
  return hidden, outputs

You only need to write things in terms of lax.scan/lax.cond if you need more performance because you want to jit the whole rnn function.

If we introduced a lax.scani kind of function, it’d just be a wrapper around lax.scan and lax.cond, but our policy is to avoid wrappers unless they’re very commonly needed.
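For readers who want such a wrapper in their own code, a rough sketch might look like the following. The scani helper here is hypothetical and not part of JAX; it assumes length is always given explicitly. Note that the index it passes to f is a traced value, so f still has to branch with lax.cond (or similar) rather than a Python if.

```python
import jax.numpy as jnp
from jax import lax

def scani(f, init, xs=None, *, length):
    """Hypothetical helper (not a JAX API): scan with f(i, carry, x) -> (carry, y)."""
    def body(carry, ix):
        i, x = ix
        return f(i, carry, x)
    # Scan over (index, xs) pairs; xs=None is an empty pytree, so x is None.
    return lax.scan(body, init, (jnp.arange(length), xs))

# Illustrative values only; none of these come from the original issue.
n_steps, t2 = 6, 3
a = 0.5 * jnp.eye(2)
u = jnp.array([1.0, -1.0])
x0 = jnp.array([2.0, 4.0])

def f(t, carry, x):
    # The index is traced, so the branch still has to go through lax.cond.
    new_carry = lax.cond(t < t2, lambda c: a @ c, lambda c: u, carry)
    return new_carry, new_carry

final, trajectory = scani(f, x0, length=n_steps)
```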

I think we covered the original question, so I’m going to close this issue (otherwise we’ll drown in issues!), but please open a new one if you have new questions!

1 reaction
tachukao commented, Mar 9, 2020

Hi Neil, thanks for the suggestion - I certainly can. I guess the problem I have now is just that I need to figure out how to use lax.cond to do control flow on the time index i in a way that is differentiable, as @mattjj suggested above. This I haven’t really explored.
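One way to check the differentiability part is to wrap a scan whose body branches on the time index with lax.cond in a scalar loss and call jax.grad on it. This is only a sketch: the loss (sum of squared states over all steps) and the numeric values are assumptions for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative values only; none of these come from the original issue.
n_steps, t2 = 4, 2
x0 = jnp.array([2.0, 4.0])

def loss(a, u):
    def step(carry, t):
        # Branch on the traced time index t.
        new_carry = lax.cond(t < t2, lambda c: a @ c, lambda c: u, carry)
        return new_carry, jnp.sum(new_carry ** 2)
    _, per_step = lax.scan(step, x0, jnp.arange(n_steps))
    return jnp.sum(per_step)

a = 0.5 * jnp.eye(2)
u = jnp.array([1.0, -1.0])
# Reverse-mode gradients with respect to both the transition matrix and the clamp value.
grad_a, grad_u = jax.grad(loss, argnums=(0, 1))(a, u)
```

Here grad_u only receives contributions from the steps at or after t2, and grad_a only from the steps before it.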
