vmap is much slower than manual vectorization
See original GitHub issue.
Please see the simple example below, which I tested on a V100 GPU: manual vectorization is about 3X faster than auto-vectorization with vmap. Is that expected? Thanks!
import ast
import time

import jax.numpy as jnp
from jax import jit, random, vmap


def parse_args():
    import argparse
    parser = argparse.ArgumentParser(description='Jax batching performance test')
    parser.add_argument('--vmap', default=False, type=ast.literal_eval, help='Whether to use vmap for vectorization')
    parser.add_argument('--jit', default=False, type=ast.literal_eval, help='Whether to jit the function')
    parser.add_argument('--T', default=1000, type=int, help='Number of timed runs')
    return parser.parse_args()


args = parse_args()


def predict(w, x):
    return jnp.matmul(x, w)


func = jit(predict) if args.jit else predict

key = random.PRNGKey(0)
x = random.normal(key, (256, 512))
w = random.normal(key, (512, 512))

start = time.time()
for _ in range(args.T):
    if args.vmap:
        vmap(func, in_axes=(None, 0))(w, x)
    else:
        func(w, x)
stop = time.time()
print("avg time: %f us" % ((stop - start) * 1e6 / args.T))
Top GitHub Comments
I see a relative slowdown on CPU too, if there’s no jit or when nesting vmap-of-jit.

As a workaround: it seems that using jit outside vmap, rather than the other way around, recovers the expected performance.

Yeah, it’s basically overheads.
Here’s a vmap-of-jit profile on GPU (using the unreleased jaxlib==0.1.65, on a colab instance so probably extra high overheads): [profile screenshot]
Here’s jit-of-vmap: [profile screenshot]
With jit-of-vmap, we hit a very fast path: the call to the jitted function jumps straight into C++, and that immediately calls into the JAX runtime (called PjRt) to enqueue a volta sgemm kernel.
The vmap-of-jit path has a lot more going on. First we do Python work in the vmap wrapper (e.g. batching.py’s batchfun). Then it hits the C++ jit dispatch path, but that has to bail out (by calling api.py’s cache_miss) because the C++ path can’t handle vmap tracers as arguments. That leads to a bunch more Python work (call_bind, process_call, call_bind, …) until we finally execute the XLA computation from Python (_execute_compiled) and get into PjRt. Then there’s a similar amount of Python work returning from the computation.

We plan to improve the C++ jit dispatch path to handle tracer inputs, which would cut out some of that overhead, but there’d still be a decent amount of the vmap overhead left. So it’ll always be best to put jit on the outside if you can!
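A minimal sketch of the two orderings, reusing predict, w, and x from the report above; with jit on the outside, batching is staged out once at trace time, so steady-state calls take the fast C++ dispatch path:

# jit-of-vmap: vmap is traced away at compile time; dispatch stays in C++.
fast = jit(vmap(predict, in_axes=(None, 0)))

# vmap-of-jit: each call pushes batching tracers through the jitted function,
# forcing dispatch back through Python (cache_miss, call_bind, ...).
slow = vmap(jit(predict), in_axes=(None, 0))

fast(w, x).block_until_ready()  # expected to be on par with the manual matmul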
(Notice that in both these cases the GPU is only doing useful work for a fraction of the time. That’s just because this is a really small computation.)
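The profile screenshots themselves didn’t survive this page. As a rough sketch of how one might capture a comparable trace (assuming a recent jax with jax.profiler; the log directory name is arbitrary, and the result is viewable in TensorBoard’s profile plugin):

import jax

# Record a trace of the jit-of-vmap path for inspection in a profile viewer.
with jax.profiler.trace("/tmp/jax-trace"):
    jit(vmap(predict, in_axes=(None, 0)))(w, x).block_until_ready()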