jax.lax.linalg.eigh on GPU and multi-core CPU doesn't parallelize appropriately.
```python
import jax
import jax.numpy as jnp

def timer(f):
    from time import time
    f()  # warmup and compile
    t = time()
    for _ in range(3):
        f()
    print((time() - t) / 3)

y = jax.random.uniform(jax.random.PRNGKey(0), (16, 1024, 1024)) / 16
s = jax.block_until_ready(y @ y.transpose(0, 2, 1) + jnp.eye(1024))

from jax.lax.linalg import eigh as jeigh
f = jax.jit(jax.vmap(jeigh))
timer(lambda: jax.block_until_ready(f(s)))  # 0.90s for 16 problems

from scipy.linalg import eigh as seigh
import numpy as np
ss = np.array(s[0])
timer(lambda: seigh(ss))  # 0.21s for 1 problem
```
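For comparison with the CPU number reported below, one way to time the same vmapped eigh on the CPU backend (my addition, not shown in the issue; it assumes a CPU device is visible alongside the GPU) is to commit the input to a CPU device, since jitted computations follow the placement of committed arguments:

```python
# Sketch (assumption, not from the issue): run the same batched eigh on the
# CPU backend by committing the input array to the first CPU device.
s_cpu = jax.device_put(s, jax.devices("cpu")[0])
timer(lambda: jax.block_until_ready(f(s_cpu)))  # CPU path; ~2.44s for 16 problems is reported below
```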
GPU: V100-PCIE 16GB; CPU: Intel® Xeon® E5-2690 v4 @ 2.60GHz.
jax.lax.linalg.eigh on one GPU takes 0.90s for 16 problems; on all CPU cores (top reports 2340% peak CPU usage) it takes 2.44s for 16 problems.
scipy.linalg.eigh on one CPU core (top reports 200% peak CPU usage) takes 0.21s for 1 problem.
This means the GPU delivers less than 4x the throughput, and more than 11x the CPU usage delivers less than 1.4x the throughput, even though the batched problem should be embarrassingly parallel under vmap.
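One hedged workaround (my sketch, not something proposed in the issue) is to expose several XLA host devices and shard the batch across them with pmap, i.e. do the SPMD/data parallelism by hand. The device count of 8 and the reshape below are illustrative assumptions:

```python
# Sketch: manual data parallelism for batched eigh on CPU.
# XLA_FLAGS must be set before JAX initializes its backends, so run this in a
# fresh process. The 8-way split is an assumption for illustration.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh as jeigh

n_dev = jax.local_device_count()  # 8 "virtual" CPU devices
y = jax.random.uniform(jax.random.PRNGKey(0), (16, 1024, 1024)) / 16
s = y @ y.transpose(0, 2, 1) + jnp.eye(1024)

# Split the batch of 16 problems into 8 shards of 2 and map one shard per device.
s_sharded = s.reshape(n_dev, -1, 1024, 1024)
v, w = jax.block_until_ready(jax.pmap(jax.vmap(jeigh))(s_sharded))
```

Whether this beats the stock multi-threaded path depends on how the LAPACK threads and the virtual XLA devices compete for cores, so it is a sketch of the manual approach rather than a recommended fix.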
Can we have PyTorch-like `set_num_threads` and `set_num_interop_threads` to control the parallelism? I find that this approach (7.5s for a 24*1024 batch of 320*320 matrices) is 50x faster than JAX on a 24-core CPU (15.6s for 1024 matrices of 320*320) and 40x faster than naively letting PyTorch use intra-op parallelism with 24 threads (12.4s for 1024 matrices of 320*320), which is itself 1.8x slower than a single thread (6.9s for 1024 matrices of 320*320) and 2.3x slower than 4 threads (5.4s for 1024 matrices of 320*320).

JAX will use multiple cores, while scipy uses only a single core. But JAX with multiple cores gets only a small speedup, at the cost of preventing the user from manually doing SPMD/data parallelism.
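The comment does not include the code behind "this approach", so the following is only a plausible reconstruction of a per-core scheme: one single-threaded scipy.linalg.eigh call per worker process over an embarrassingly parallel batch. The batch shape, worker count, and OMP_NUM_THREADS setting are assumptions mirroring the numbers quoted above:

```python
# Hypothetical reconstruction (not the commenter's actual code): solve a batch
# of symmetric 320x320 eigenproblems with one single-threaded worker per core.
import os
os.environ.setdefault("OMP_NUM_THREADS", "1")  # keep each worker single-threaded (assumption)

import numpy as np
from scipy.linalg import eigh
from multiprocessing import Pool

def solve_one(a):
    return eigh(a)  # returns (eigenvalues, eigenvectors)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((1024, 320, 320))
    mats = x @ x.transpose(0, 2, 1) + 320 * np.eye(320)  # symmetric positive-definite batch
    with Pool(processes=24) as pool:  # one worker per core on a 24-core CPU
        results = pool.map(solve_one, list(mats))  # embarrassingly parallel over the batch
```

Process-level parallelism like this sidesteps BLAS/LAPACK thread contention entirely, which is presumably why it scales better than the intra-op threading being compared against.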