
vmap is much slower than manual vectorization

See original GitHub issue

Please see the simple example below. I tested its performance on a V100 GPU: manual vectorization is about 3X faster than using vmap for automatic vectorization. Is this expected? Thanks!

import ast
import jax.numpy as jnp
from jax import jit, random, vmap
import time

def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='Jax batching performance test')
    parser.add_argument("--vmap", default=False, type=ast.literal_eval, help="Whether using vmap for vectorization")
    parser.add_argument("--jit", default=False, type=ast.literal_eval, help="Whether do jit for function")
    parser.add_argument('--T', default=1000, type=int, help='Run times')
    args = parser.parse_args()
    return args

args = parse_args()

def predict(w, x):
    outputs = jnp.matmul(x, w)
    return outputs

func = jit(predict) if args.jit else predict

key = random.PRNGKey(0)
x = random.normal(key, (256, 512))
w = random.normal(key, (512, 512))

start = time.time()
for _ in range(args.T):
    if args.vmap:
        # Batch over rows of x; w is shared across the batch.
        out = vmap(func, in_axes=(None, 0))(w, x)
    else:
        out = func(w, x)
# JAX dispatches asynchronously, so block before reading the clock.
out.block_until_ready()
stop = time.time()

print("avg time: %f us" % ((stop - start) * 1e6 / args.T))

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

2 reactions
froystig commented, Apr 2, 2021

I see a relative slowdown on CPU too, if there’s no jit or when nesting vmap-of-jit:

$ cat test.py
from functools import partial
from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

n, d = 512, 64
a = random.normal(random.PRNGKey(0), (n, d))
b = random.normal(random.PRNGKey(0), (d, d))

mm = jnp.matmul
v = partial(vmap, in_axes=(0, None))  # batch over rows of a; b is shared

# Timings print in this order: mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm)).
for f in [mm, v(mm), jit(mm), v(jit(mm)), jit(v(mm))]:
  run = lambda: f(a, b).block_until_ready()
  t = timeit(run, setup=run, number=1000)
  print(f'{t:.3f}')
$ python test.py
0.308
0.545
0.150
0.344
0.149

As a workaround, it seems that applying jit outside vmap, rather than the other way around, recovers the expected performance.
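Applied to the original script, that workaround would look something like this (a sketch using the issue's predict, w, and x; the jitted, vmapped function is built once, outside the timing loop):

# jit on the outside: the whole batched computation compiles into one XLA call,
# and each invocation stays on the fast dispatch path.
batched_predict = jit(vmap(predict, in_axes=(None, 0)))
out = batched_predict(w, x).block_until_ready()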

1 reaction
mattjj commented, Apr 2, 2021

Yeah, it’s basically overheads.

Here’s a vmap-of-jit profile on GPU (using the unreleased jaxlib==0.1.65, on a colab instance so probably extra high overheads):

[profiler screenshot: vmap-of-jit trace]

Here’s jit-of-vmap:

[profiler screenshot: jit-of-vmap trace]
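(For anyone wanting to reproduce traces like these, JAX's profiler can write one for TensorBoard; a minimal sketch, assuming a jax version with the jax.profiler.trace context manager:)

import jax
from jax import vmap, jit, random, numpy as jnp

w = random.normal(random.PRNGKey(0), (512, 512))
x = random.normal(random.PRNGKey(0), (256, 512))

f = vmap(jit(jnp.matmul), in_axes=(0, None))   # the vmap-of-jit case from above
with jax.profiler.trace('/tmp/jax-trace'):     # open the trace in TensorBoard's profile plugin
    f(x, w).block_until_ready()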

With jit-of-vmap, we hit a very fast path: the call to the jitted function jumps straight into C++, which immediately calls into the JAX runtime (called PjRt) to enqueue a volta sgemm kernel.

The vmap-of-jit path has a lot more going on. First we do Python work in the vmap wrapper (e.g. batching.py’s batchfun). Then it hits the C++ jit dispatch path, but that has to bail out (by calling api.py’s cache_miss) because the C++ path can’t handle vmap tracers as arguments. That leads to a bunch more Python work (call_bind, process_call, call_bind, …) until we finally execute the XLA computation from Python (_execute_compiled) and get into PjRt. Then there’s a similar amount of Python work returning from the computation.
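The per-call dispatch cost being described can be isolated by timing without blocking on the results (a rough sketch; this mostly measures the Python/C++ dispatch work rather than device compute):

from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

w = random.normal(random.PRNGKey(0), (512, 512))
x = random.normal(random.PRNGKey(0), (256, 512))

voj = vmap(jit(jnp.matmul), in_axes=(0, None))  # vmap-of-jit: bails out to Python dispatch
jov = jit(vmap(jnp.matmul, in_axes=(0, None)))  # jit-of-vmap: stays on the C++ fast path

for name, f in [('vmap-of-jit', voj), ('jit-of-vmap', jov)]:
    f(x, w).block_until_ready()                 # compile and warm up before timing
    print(name, timeit(lambda: f(x, w), number=1000))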

We plan to improve the C++ jit dispatch path to handle tracer inputs, which would cut out some of that overhead, but there’d still be a decent amount of the vmap overhead left. So it’ll always be best to put jit on the outside if you can!

(Notice that in both these cases the GPU is only doing useful work for a fraction of the time. That’s just because this is a really small computation.)
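One way to see that the gap is mostly fixed per-call overhead: scale the computation up and time both orderings end to end; the relative difference should shrink as the kernel itself gets more expensive (a sketch; absolute numbers depend on hardware):

from timeit import timeit
from jax import vmap, jit, random, numpy as jnp

# A much larger matmul, so device compute dominates per-call dispatch overhead.
n, d = 512, 2048
a = random.normal(random.PRNGKey(0), (n, d))
b = random.normal(random.PRNGKey(0), (d, d))

for name, f in [('vmap-of-jit', vmap(jit(jnp.matmul), in_axes=(0, None))),
                ('jit-of-vmap', jit(vmap(jnp.matmul, in_axes=(0, None))))]:
    run = lambda: f(a, b).block_until_ready()
    t = timeit(run, setup=run, number=100)
    print(f'{name}: {t:.3f}s')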


