`lax.scan` is ~6x slower to run than hand-written loops

This is on a GPU backend; I haven't tried other backends.

https://colab.research.google.com/drive/1N1jrvGNFRhTnYLOCL6vXUzJN3pRR3pdr

Minimal repro below

```python
from jax import numpy as np
from jax import grad, jit, vmap, lax
from jax import random as jax_random

import numpy as onp

@jit
def rewards_to_go(rewards, mask, gamma=0.99):
  r"""Computes rewards to go.
  Args:
    rewards: np.ndarray of shape (B, T) of rewards.
    mask: np.ndarray of shape (B, T) of mask for the rewards.
    gamma: float, discount factor.

  Returns:
    rewards to go, np.ndarray of shape (B, T).
  """
  B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

  masked_rewards = rewards * mask  # (B, T)

  # Compute r2g_{T-1} at the start and then compute backwards in time.
  r2gs = [masked_rewards[:, -1]]

  # Go from T-2 down to 0.
  for t in reversed(range(T - 1)):
    r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

  # The list should have length T.
  assert T == len(r2gs)

  # First we stack them in the correct way to make it (B, T), but these are
  # still from newest (T-1) to oldest (0), so then we flip it on time axis.
  return np.flip(np.stack(r2gs, axis=1), axis=1)


@jit
def scan_rewards_to_go(rewards, mask, gamma=0.99):
  masked_rewards = rewards * mask  # (B, T)

  reversed_rewards = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
  rrt = np.transpose(reversed_rewards)  # (T, B) transpose to scan over time.

  def discounting_add(carry, reward):
    x = reward + (gamma * carry)
    return x, x

  _, ys = lax.scan(discounting_add,
                   np.zeros_like(rrt[0], dtype=np.float32),
                   rrt.astype(np.float32))

  # ys is (T, B) and T is in reverse order.
  return np.flip(np.transpose(ys), axis=1)

B, T = 16, 128

num_examples = 100
rewards = []
pvs = []
mask = []

for _ in range(num_examples):
  rewards.append(onp.random.randn(B, T))
  pvs.append(onp.random.randn(B, T+1))
  ones = onp.full((B, T), 1, dtype=onp.int32)
  for row in ones:
    # Zero out a random-length suffix of each row (time steps l..T-1).
    l = onp.random.randint(0, T)
    row[l:] = 0
  mask.append(ones)
```
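To sanity-check that a scan-based version agrees with the hand-written loop, here is a minimal sketch. It uses the `reverse=True` argument of `lax.scan`, which iterates from the last element to the first but returns outputs in the original order, so the explicit flips are not needed. Both functions compute the recurrence r2g[t] = mask[t] * rewards[t] + gamma * r2g[t+1] with r2g[T] = 0. The names `reverse_scan_rewards_to_go` and `numpy_rewards_to_go` are hypothetical. The NumPy loop is only a test reference and is not part of the original report. The sketch assumes float rewards.

```python
import numpy as onp
import jax.numpy as jnp
from jax import jit, lax


@jit
def reverse_scan_rewards_to_go(rewards, mask, gamma=0.99):
  """Rewards to go for (B, T) inputs, scanning backwards over time."""
  # lax.scan iterates over the leading axis, so move time to the front: (T, B).
  masked = jnp.transpose(rewards * mask)

  def step(carry, reward):
    # r2g[t] = masked_reward[t] + gamma * r2g[t + 1]
    carry = reward + gamma * carry
    return carry, carry

  init = jnp.zeros(masked.shape[1], dtype=masked.dtype)  # r2g[T] = 0, shape (B,)
  # reverse=True visits t = T-1, ..., 0 but stacks ys in the original order,
  # so ys[t] already holds r2g[t] and no flip is needed.
  _, ys = lax.scan(step, init, masked, reverse=True)
  return jnp.transpose(ys)  # back to (B, T)


def numpy_rewards_to_go(rewards, mask, gamma=0.99):
  """Plain NumPy reference loop, used only to check results."""
  masked = onp.asarray(rewards) * onp.asarray(mask)
  out = onp.zeros(masked.shape, dtype=onp.float64)
  running = onp.zeros(masked.shape[0], dtype=onp.float64)
  for t in reversed(range(masked.shape[1])):
    running = masked[:, t] + gamma * running
    out[:, t] = running
  return out
```

With the repro's data, comparing `reverse_scan_rewards_to_go(rewards[0], mask[0])` against `numpy_rewards_to_go(rewards[0], mask[0])` via `onp.allclose` needs a loose tolerance. JAX computes in float32 by default, while the reference uses float64.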

Now time the invocations:

```python
%timeit [rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
```

and

```python
%timeit [scan_rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
```
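Because JAX dispatches work asynchronously, `%timeit` over a list of jitted calls can under-report the cost. The list may be built before the device has finished computing. Here is a minimal timing sketch that compiles once in a warm-up call and then blocks on every result. `time_calls` is a hypothetical helper, not part of JAX. It assumes `fn` returns a single array and that all inputs share the shapes of the first one, which holds in the repro. Otherwise later calls would recompile inside the timed region.

```python
import time

import jax


def time_calls(fn, inputs, repeats=5):
  """Best wall-clock seconds, over `repeats` runs, to call `fn(*args)` for
  every `args` tuple in `inputs` and wait for all results."""
  # Warm-up call: triggers tracing and compilation (for the first input's
  # shapes/dtypes) so that compile time is excluded from the measurement.
  fn(*inputs[0]).block_until_ready()
  best = float("inf")
  for _ in range(repeats):
    start = time.perf_counter()
    outs = [fn(*args) for args in inputs]
    # JAX dispatches asynchronously; block on every result so the timer
    # covers device execution, not only dispatch.
    for out in outs:
      out.block_until_ready()
    best = min(best, time.perf_counter() - start)
  return best
```

For the repro, this would be called as `time_calls(rewards_to_go, list(zip(rewards, mask)))` and likewise for `scan_rewards_to_go`. Both timings include the same host-to-device transfer of the NumPy inputs.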

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 15 (10 by maintainers)

Top GitHub Comments

1 reaction
jekbradbury commented, Oct 30, 2019

I had a partial solution in the linked PR, but we ended up deciding against that approach and sketched out a more direct (and complete) solution two weeks ago. I haven’t implemented it yet, though it’s been on my todos; I expect I’ll have a new PR sometime in the next week or so.

0 reactions
gehring commented, Oct 29, 2019

@jekbradbury Any progress or updates on this? No pressure, I’m just trying to get an idea as to whether a fix is in the works or if I should commit to a workaround.
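As a possible workaround, recent JAX releases provide an `unroll` argument to `lax.scan`. It places several loop iterations in each trip of the underlying XLA loop, which may narrow the gap to a fully unrolled Python loop. This is not necessarily the fix discussed in this thread. The function name `unrolled_rewards_to_go` and the default `unroll=8` are arbitrary choices for illustration. Whether unrolling helps, and which value works best, depends on the backend and should be measured.

```python
from functools import partial

import jax.numpy as jnp
from jax import jit, lax


@partial(jit, static_argnames=("unroll",))
def unrolled_rewards_to_go(rewards, mask, gamma=0.99, unroll=8):
  """Rewards to go for (B, T) inputs using a partially unrolled lax.scan."""
  masked = jnp.transpose(rewards * mask)  # (T, B): scan over time

  def step(carry, reward):
    carry = reward + gamma * carry  # r2g[t] = r[t] + gamma * r2g[t + 1]
    return carry, carry

  init = jnp.zeros(masked.shape[1], dtype=masked.dtype)
  # unroll must be a Python int at trace time, hence static_argnames above.
  _, ys = lax.scan(step, init, masked, reverse=True, unroll=unroll)
  return jnp.transpose(ys)
```

Each distinct `unroll` value triggers a separate compilation, so a sweep over values should time each setting after its own warm-up call.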
