
Surprisingly slow jax.lax.top_k

See original GitHub issue

Hi,

I was writing some functions with jax.lax.top_k and noticed that it was surprisingly slow. I would expect top_k to be quite fast, since it only requires a single pass through the data after all. Is that right?
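For concreteness, the single-pass expectation can be sketched in plain Python with a bounded min-heap, giving O(n log k) work. This is only an illustration of the expected complexity, not what XLA actually does; the function name is mine:

```python
import heapq

def single_pass_top_k(row, k):
    """Exact top-k in one pass over the data, O(n log k):
    heapq.nlargest maintains a size-k heap internally."""
    return heapq.nlargest(k, row)

print(single_pass_top_k([0.2, 0.9, 0.1, 0.7, 0.4], 3))  # [0.9, 0.7, 0.4]
```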

To get a better idea of its performance compared to other operations, I ran a quick benchmark (code snippet at the end of the issue).

I compared jax.lax.top_k (and jax.lax.approx_max_k) to jnp.argmax and jnp.argsort, using k=10 and data of shape (256, 80000) (this arbitrary choice comes from the kinds of sizes I use in my applications). I am running these on an NVIDIA T4 GPU.

What I observed is that:

  • jax.lax.top_k is almost as slow as jnp.argsort for those parameters (0.04 s vs 0.058 s), whereas I expected it to be much faster
  • jax.lax.top_k is much slower than jnp.argmax (0.04 s vs 0.0005 s), whereas I expected them to be almost equally fast

Observing such a difference with jnp.argmax, I implemented a naïve top_k on top of argmax, which turns out to be faster than top_k (and approx_max_k) for small values of k (with my data size). Indeed, in my use case, this simple implementation (provided in the code snippet below as naive_top_k) is 7 times faster than top_k (0.04 s vs 0.0056 s).

Is this a known behavior? Is there a reason for this?

I would be very interested in some insight into this (and a potential fix in the future 🙂)!

import time

import jax
import jax.numpy as jnp
import numpy as np

if __name__ == "__main__":

    # create a random key
    key = jax.random.PRNGKey(seed=0)

    # create some random data
    data = jax.random.uniform(key, shape=(256, 80000)).block_until_ready()

    # create jitted functions with fixed k
    k = 10

    @jax.jit
    def jitted_top_k(data):
        values, _indices = jax.lax.top_k(data, k=k)
        return values

    @jax.jit
    def jitted_approx_max_k(data):
        values, _indices = jax.lax.approx_max_k(data, k=k)
        return values

    jitted_argmax = jax.jit(jnp.argmax)
    jitted_argsort = jax.jit(jnp.argsort)

    # Let's benchmark those functions

    N = 20
    M = 5  # skip the first M runs (compilation and warm-up)

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_top_k(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jax.lax.top_k : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_approx_max_k(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jax.lax.approx_max_k : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_argsort(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jnp.argsort : ", np.mean(times))

    times = []
    for i in range(N):
        start_time = time.time()
        jitted_argmax(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for jnp.argmax : ", np.mean(times))

    # Surprisingly, top_k and approx_max_k are almost
    # as slow as argsort

    # Surprisingly, top_k and approx_max_k are much
    # slower than argmax

    # Let's build a top_k with argmax

    def naive_top_k(data, k):
        """Top-k implementation built with argmax.
        Faster for small k."""

        def top_1(data):
            # per-row argmax and the corresponding values
            index = jnp.argmax(data, axis=1)
            value = jax.vmap(lambda x, i: x[i])(data, index)
            # mask out the current maximum so the next step
            # finds the next-largest value
            data = jax.vmap(lambda x, i: x.at[i].set(-jnp.inf))(data, index)
            return data, value, index

        def scannable_top_1(carry, _):
            data, value, index = top_1(carry)
            return data, (value, index)

        data, (values, indices) = jax.lax.scan(
            scannable_top_1, data, xs=None, length=k
        )

        # scan stacks along the leading axis: (k, batch) -> (batch, k)
        return values.T, indices.T

    @jax.jit
    def jitted_naive_top_k(data):
        values, _indices = naive_top_k(data, k=k)
        return values

    # benchmark our new top k
    times = []
    for i in range(N):
        start_time = time.time()
        jitted_naive_top_k(data).block_until_ready()
        elapsed_time = time.time() - start_time

        if i >= M:
            times.append(elapsed_time)

    print("Time for naive top k : ", np.mean(times))

My output on a T4 GPU:

Time for jax.lax.top_k :  0.04046181042989095
Time for jax.lax.approx_max_k :  0.02177149454752604
Time for jnp.argsort :  0.0582158088684082
Time for jnp.argmax :  0.0005081971486409505
Time for naive top k :  0.005642461776733399

Thanks!

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 7 (1 by maintainers)

Top GitHub Comments

dryman commented on Apr 1, 2022 (1 reaction)

We implemented jax.lax.approx_max_k for TPU to address the slowness of top-k. (Up to 148x faster in some scenarios.)

Maybe we can prioritize implementing similar algorithms for GPU in the next quarter.
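For readers landing here: jax.lax.approx_max_k exposes a recall_target parameter that trades recall for speed where the approximate algorithm is available (per the comment above, currently TPU). A minimal sketch of the call, with illustrative shapes:

```python
import jax
import jax.numpy as jnp

# small random batch, same layout as in the benchmark above
data = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 1000))

# recall_target relaxes the required recall of the result;
# values/indices come back with shape (batch, k)
values, indices = jax.lax.approx_max_k(data, k=10, recall_target=0.9)
print(values.shape, indices.shape)  # (8, 10) (8, 10)
```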

felixchalumeau commented on Nov 2, 2022

Hi there 👋

Any news @LenaMartens @dryman?
