`lax.scan` is ~6x slower to run than hand-written loops
This is on a GPU backend; I haven't tried other backends.
https://colab.research.google.com/drive/1N1jrvGNFRhTnYLOCL6vXUzJN3pRR3pdr
Minimal repro below:
```python
from jax import numpy as np
from jax import grad, jit, vmap, lax
from jax import random as jax_random
import numpy as onp


@jit
def rewards_to_go(rewards, mask, gamma=0.99):
  r"""Computes rewards to go.

  Args:
    rewards: np.ndarray of shape (B, T) of rewards.
    mask: np.ndarray of shape (B, T) of mask for the rewards.
    gamma: float, discount factor.

  Returns:
    rewards to go, np.ndarray of shape (B, T).
  """
  B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

  masked_rewards = rewards * mask  # (B, T)

  # Compute r2g_{T-1} at the start and then compute backwards in time.
  r2gs = [masked_rewards[:, -1]]

  # Go from T-2 down to 0.
  for t in reversed(range(T - 1)):
    r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

  # The list should have length T.
  assert T == len(r2gs)

  # First we stack them in the correct way to make it (B, T), but these are
  # still from newest (T-1) to oldest (0), so then we flip it on time axis.
  return np.flip(np.stack(r2gs, axis=1), axis=1)


@jit
def scan_rewards_to_go(rewards, mask, gamma=0.99):
  masked_rewards = rewards * mask  # (B, T)
  reversed_rewards = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
  rrt = np.transpose(reversed_rewards)  # (T, B) transpose to scan over time.

  def discounting_add(carry, reward):
    x = reward + (gamma * carry)
    return x, x

  _, ys = lax.scan(discounting_add,
                   np.zeros_like(rrt[0], dtype=np.float32),
                   rrt.astype(np.float32))

  # ys is (T, B) and T is in reverse order.
  return np.flip(np.transpose(ys), axis=1)
```
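For comparison, current JAX releases let `lax.scan` walk backwards on its own via `reverse=True`, which removes both `np.flip` calls: the outputs come back already in forward time order. The `unroll` argument packs several time steps into each compiled loop iteration, trading a larger compiled program for fewer loop iterations. This is a sketch under assumptions: `unroll` may not exist in older JAX versions, the value 8 is arbitrary, the name `scan_rewards_to_go_reverse` is hypothetical, and whether this closes the gap with the unrolled Python loop on a given GPU is untested.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def scan_rewards_to_go_reverse(rewards, mask, gamma=0.99):
  # (B, T) -> (T, B): lax.scan iterates over the leading axis, i.e. time.
  masked_rewards = (rewards * mask).astype(jnp.float32).T

  def discounting_add(carry, reward):
    x = reward + gamma * carry
    return x, x

  # reverse=True runs from t = T-1 down to t = 0, but ys[t] is still stored
  # at index t, so the stacked output is already in forward time order.
  # unroll=8 (an arbitrary choice) puts 8 time steps in each loop iteration.
  _, ys = lax.scan(discounting_add,
                   jnp.zeros(masked_rewards.shape[1], dtype=jnp.float32),
                   masked_rewards,
                   reverse=True,
                   unroll=8)
  return ys.T  # (T, B) -> (B, T)
```

It can be timed with the same `%timeit` line as the other two functions.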
```python
B, T = 16, 128
num_examples = 100

rewards = []
pvs = []
mask = []
for _ in range(num_examples):
  rewards.append(onp.random.randn(B, T))
  pvs.append(onp.random.randn(B, T+1))
  ones = onp.full((B, T), 1, dtype=onp.int32)
  for one in ones:
    l = onp.random.randint(0, T)
    one[range(l,T)] = 0
  mask.append(ones)
```
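Before comparing timings, it can help to confirm that both jitted functions compute the same quantity. A minimal plain-NumPy reference, written as a hypothetical helper `reference_rewards_to_go`, applies the same backward recurrence r2g[t] = mask[t] * rewards[t] + gamma * r2g[t + 1], with r2g taken as 0 after the last step:

```python
import numpy as onp


def reference_rewards_to_go(rewards, mask, gamma=0.99):
  """Plain-NumPy rewards to go: r2g[:, t] = m[:, t] * r[:, t] + gamma * r2g[:, t+1]."""
  masked = onp.asarray(rewards, dtype=onp.float64) * onp.asarray(mask)
  out = onp.zeros_like(masked)
  running = onp.zeros(masked.shape[0], dtype=masked.dtype)
  # Walk backwards in time, accumulating the discounted sum per batch row.
  for t in reversed(range(masked.shape[1])):
    running = masked[:, t] + gamma * running
    out[:, t] = running
  return out
```

With the data above, a check such as `onp.allclose(reference_rewards_to_go(rewards[0], mask[0]), onp.asarray(scan_rewards_to_go(rewards[0], mask[0])), atol=1e-4)` should hold for both jitted versions; the loose tolerance allows for JAX computing in float32 by default.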
Now time the invocations:

```python
%timeit [rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
```

and

```python
%timeit [scan_rewards_to_go(rewards[i], mask[i]) for i in range(num_examples)]
```
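One caveat when reading these numbers: JAX dispatches work asynchronously, so `%timeit` over a list comprehension may partly measure Python-side dispatch rather than GPU execution, and the first run also includes compilation. A minimal timing sketch (the helper name `time_calls` and its `repeats` parameter are hypothetical) that warms up once and calls `block_until_ready()` on every result:

```python
import time


def time_calls(fn, rewards, mask, repeats=10):
  """Mean wall-clock seconds for one pass of `fn` over all (rewards, mask) pairs."""
  # Warm-up pass: the first call of a jitted function includes tracing and
  # compilation, which should not count toward the steady-state time.
  for r, m in zip(rewards, mask):
    fn(r, m).block_until_ready()

  start = time.perf_counter()
  for _ in range(repeats):
    outs = [fn(r, m) for r, m in zip(rewards, mask)]
    # Wait for every result so the device work falls inside the timed
    # region, not just the asynchronous dispatch.
    for out in outs:
      out.block_until_ready()
  return (time.perf_counter() - start) / repeats
```

Usage: compare `time_calls(rewards_to_go, rewards, mask)` with `time_calls(scan_rewards_to_go, rewards, mask)`.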
Issue analytics: created 4 years ago; 15 comments (10 by maintainers).
I had a partial solution in the linked PR, but we decided against that approach and sketched out a more direct (and complete) solution two weeks ago. I haven't implemented it yet, though it's on my to-do list; I expect to have a new PR sometime in the next week or so.
@jekbradbury Any progress or updates on this? No pressure, I’m just trying to get an idea as to whether a fix is in the works or if I should commit to a workaround.