Surprisingly slow jax.lax.top_k
Hi,
I was writing some functions with `jax.lax.top_k` and noticed that it was particularly slow. I would expect a top_k to be quite fast, since in principle it only requires a single pass through the data — is that right?
To get a better idea of its performance compared to other operations, I ran a quick benchmark (code snippet at the end of the issue).
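For context, here is roughly what I mean by "a single pass": a bounded min-heap yields the top k in one O(n log k) sweep. This is just an illustrative sketch in plain Python, not how XLA actually implements `jax.lax.top_k`:

```python
import heapq

def top_k_single_pass(xs, k):
    """Return the k largest values (descending) in one pass, O(n log k)."""
    heap = []  # min-heap holding the k largest values seen so far
    for x in xs:
        if len(heap) < k:
            heapq.heappush(heap, x)
        elif x > heap[0]:
            # x beats the smallest of the current top-k: swap it in
            heapq.heapreplace(heap, x)
    return sorted(heap, reverse=True)

print(top_k_single_pass([5, 1, 9, 3, 7, 2], k=3))  # [9, 7, 5]
```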
I compared `jax.lax.top_k` (and `jax.lax.approx_max_k`) to `jnp.argmax` and `jnp.argsort`. I used k=10 and data of shape (256, 80000) (this somewhat arbitrary choice matches the sizes I use in my applications). I am running these on a T4 GPU.
What I observed is that:
- `jax.lax.top_k` is almost as fast as `jnp.argsort` for those parameters (0.04 vs 0.058), whereas I expected it to be much faster;
- `jax.lax.top_k` is much slower than `jnp.argmax` (0.04 vs 0.0005), whereas I expected them to be almost as fast.
Observing such a difference with `jnp.argmax`, I decided to implement a naïve top_k built on argmax, which turns out to be faster than top_k (and approx_max_k) for small values of k (with my data size). And indeed, in my use case, this simple implementation (provided in the code snippet below, named `naive_top_k`) is 7 times faster than top_k (0.04 vs 0.0056).
Is this a known behavior? Is there a reason for this?
I would be very interested in some insight about this (and a potential fix in the future 🙂)!
```python
import time

import jax
import jax.numpy as jnp
import numpy as np

if __name__ == "__main__":
    # create a random key
    key = jax.random.PRNGKey(seed=0)

    # create some random data
    data = jax.random.uniform(key, shape=(256, 80000)).block_until_ready()

    # create jitted functions with fixed k
    k = 10

    @jax.jit
    def jitted_top_k(data):
        values, _indices = jax.lax.top_k(data, k=10)
        return values

    @jax.jit
    def jitted_approx_max_k(data):
        values, _indices = jax.lax.approx_max_k(data, k=10)
        return values

    jitted_argmax = jax.jit(jnp.argmax)
    jitted_argsort = jax.jit(jnp.argsort)

    # Let's benchmark those functions
    N = 20
    M = 5  # avoid taking the first times

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_top_k(data).block_until_ready()
        elapsed_time = time.time() - start_time
        if i >= M:
            times.append(elapsed_time)
    print("Time for jax.lax.top_k : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_approx_max_k(data).block_until_ready()
        elapsed_time = time.time() - start_time
        if i >= M:
            times.append(elapsed_time)
    print("Time for jax.lax.approx_max_k : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_argsort(data).block_until_ready()
        elapsed_time = time.time() - start_time
        if i >= M:
            times.append(elapsed_time)
    print("Time for jnp.argsort : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_argmax(data).block_until_ready()
        elapsed_time = time.time() - start_time
        if i >= M:
            times.append(elapsed_time)
    print("Time for jnp.argmax : ", np.mean(times))

    # Surprisingly, top_k and approx_max_k are almost
    # as fast as argsort
    # Surprisingly, top_k and approx_max_k are much
    # slower than argmax

    # Let's build a top_k with argmax
    def naive_top_k(data, k):
        """Top k implementation built with argmax.
        Faster for smaller k."""

        def top_1(data):
            indice = jnp.argmax(data, axis=1)
            value = jax.vmap(lambda x, y: x[y])(data, indice)
            data = jax.vmap(lambda x, y: x.at[y].set(-jnp.inf))(data, indice)
            return data, value, indice

        def scannable_top_1(carry, unused):
            data = carry
            data, value, indice = top_1(data)
            return data, (value, indice)

        data, (values, indices) = jax.lax.scan(scannable_top_1, data, (), k)
        return values.T, indices.T

    @jax.jit
    def jitted_naive_top_k(data):
        values, _indices = naive_top_k(data, k=10)
        return values

    # benchmark our new top k
    times = []
    for i in range(N):
        start_time = time.time()
        jitted_naive_top_k(data).block_until_ready()
        elapsed_time = time.time() - start_time
        if i >= M:
            times.append(elapsed_time)
    print("Time for naive top k : ", np.mean(times))
```
My output with a T4 GPU:

```
Time for jax.lax.top_k : 0.04046181042989095
Time for jax.lax.approx_max_k : 0.02177149454752604
Time for jnp.argsort : 0.0582158088684082
Time for jnp.argmax : 0.0005081971486409505
Time for naive top k : 0.005642461776733399
```
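The repeated-argmax trick is easy to sanity-check against a full sort. Here is a plain-NumPy mirror of `naive_top_k` (the helper name `naive_top_k_np` is mine, purely for illustration):

```python
import numpy as np

def naive_top_k_np(data, k):
    """Repeated argmax: record the row-wise max, mask it with -inf, repeat k times."""
    data = np.array(data, dtype=float)  # work on a copy
    rows = np.arange(data.shape[0])
    values, indices = [], []
    for _ in range(k):
        idx = np.argmax(data, axis=1)
        values.append(data[rows, idx])
        indices.append(idx)
        data[rows, idx] = -np.inf  # mask the found maximum
    return np.stack(values, axis=1), np.stack(indices, axis=1)

rng = np.random.default_rng(0)
x = rng.random((4, 100))
vals, idxs = naive_top_k_np(x, k=5)
# values should match the 5 largest per row, in descending order
print(np.allclose(vals, -np.sort(-x, axis=1)[:, :5]))  # True
```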
Thanks!
Issue Analytics
- Created 2 years ago
- Comments: 7 (1 by maintainers)
We implemented `jax.lax.approx_max_k` for TPU to address the slowness of top-k (up to 148x faster in some scenarios). Maybe we can prioritize implementing similar algorithms for GPU in the next quarter.
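For anyone landing here: `approx_max_k` trades exactness for speed via its `recall_target` argument. A minimal usage sketch (the 0.95 value is just illustrative):

```python
import jax
import jax.numpy as jnp

x = jax.random.uniform(jax.random.PRNGKey(0), shape=(256, 80000))
# Ask for the top 10 per row, accepting roughly 95% recall of the true top-10.
values, indices = jax.lax.approx_max_k(x, k=10, recall_target=0.95)
print(values.shape, indices.shape)  # (256, 10) (256, 10)
```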
Hi there 👋
Any news @LenaMartens @dryman?