
jax.lax.linalg.eigh on GPU and multi-core CPU doesn't parallelize appropriately.

See original GitHub issue
import jax
import jax.numpy as jnp

def timer(f):
    from time import time
    f() # warmup and compile
    t = time()
    for _ in range(3):
        f()
    print((time() - t) / 3)

y = jax.random.uniform(jax.random.PRNGKey(0), (16, 1024, 1024)) / 16
s = jax.block_until_ready(y @ y.transpose(0, 2, 1) + jnp.eye(1024))  # batch of 16 SPD 1024x1024 matrices

from jax.lax.linalg import eigh as jeigh
f = jax.jit(jax.vmap(jeigh))
timer(lambda: jax.block_until_ready(f(s))) # 0.90s for 16 problems

from scipy.linalg import eigh as seigh
import numpy as np
ss = np.array(s[0])  # a single matrix from the batch
timer(lambda: seigh(ss)) # 0.21s for 1 problem

GPU: V100-PCIE 16 GB
CPU: Intel® Xeon® E5-2690 v4 @ 2.60GHz

jax.lax.linalg.eigh on 1 GPU takes 0.90s for 16 problems. On all CPU cores (top reports 2340% peak CPU usage) it takes 2.44s for 16 problems. scipy.linalg.eigh on 1 CPU core (top reports 200% peak CPU usage) takes 0.21s for 1 problem. In other words, the GPU delivers less than 4x the throughput, and more than 11x the CPU usage delivers less than 1.4x the throughput, even though the vmapped computation should be embarrassingly parallel.
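One possible workaround on CPU is to make the data parallelism explicit instead of relying on vmap. The following is a sketch, not from the original report: it assumes a fresh process, uses the standard --xla_force_host_platform_device_count XLA flag to expose the host as several devices, and shards the batch with pmap; the device count and shapes are illustrative.

# Sketch of a possible CPU workaround (not from the report): expose the host as
# several XLA devices and run one eigendecomposition per device with pmap.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"  # must be set before importing jax

import jax
import jax.numpy as jnp
from jax.lax.linalg import eigh as jeigh

n_dev = jax.local_device_count()  # 16 virtual CPU devices
y = jax.random.uniform(jax.random.PRNGKey(0), (n_dev, 1024, 1024)) / 16
s = y @ y.transpose(0, 2, 1) + jnp.eye(1024)  # one SPD matrix per device

f = jax.pmap(jeigh)  # data-parallel over the leading (device) axis
vecs, vals = jax.block_until_ready(f(s))  # lax.linalg.eigh returns (eigenvectors, eigenvalues)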

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

3 reactions
YouJiacheng commented, Apr 7, 2022

Can we have PyTorch-like set_num_threads and set_num_interop_threads to control the parallelism?

import torch

torch.set_num_threads(1)           # one intra-op thread per task
torch.set_num_interop_threads(24)  # 24 concurrent inter-op tasks

@torch.jit.script
def mt_eigh(x: torch.Tensor):
    # Fork one eigh per matrix onto the inter-op thread pool, then wait for all of them.
    futs = [torch.jit._fork(torch.linalg.eigh, x[i]) for i in range(24)]
    return [torch.jit._wait(fut) for fut in futs]

I find that this approach (7.5s for 24×1024 matrices of size 320×320) is about 50x faster per problem than JAX on the 24-core CPU (15.6s for 1024 matrices of size 320×320), and about 40x faster than naively letting PyTorch use intra-op parallelism with 24 threads (12.4s for 1024 matrices), which is itself 1.8x slower than a single thread (6.9s) and 2.3x slower than 4 threads (5.4s).
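For clarity, here is the per-1024-matrix arithmetic behind those ratios, using only the timings quoted above:

# Throughput arithmetic (seconds per 1024 matrices of size 320x320), from the numbers above.
forked   = 7.5 / 24   # ~0.31 s: 24 forked single-thread workers did 24x1024 matrices in 7.5 s
jax_cpu  = 15.6       # JAX on the 24-core CPU
intra_op = 12.4       # PyTorch intra-op parallelism with 24 threads
single   = 6.9        # PyTorch with 1 thread
four     = 5.4        # PyTorch with 4 threads
print(jax_cpu / forked)    # ~50x
print(intra_op / forked)   # ~40x
print(intra_op / single)   # ~1.8x slower than a single thread
print(intra_op / four)     # ~2.3x slower than 4 threads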

0 reactions
YouJiacheng commented, Apr 7, 2022

> In fact, it's a LAPACK function provided by scipy that we use, so I'd be surprised if you saw any speedup over scipy at all. That said, the algorithm does use parallelism internally, at least for some of the phases.

JAX will use multiple cores, while scipy uses only a single core. But JAX with multiple cores gets only a small speedup, at the cost of preventing the user from manually applying SPMD/data parallelism.
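For concreteness, here is a minimal sketch (not from the thread) of the kind of manual data parallelism meant here: one single-core scipy.linalg.eigh per worker process. The pool size and shapes are illustrative.

# Manual data parallelism with single-core scipy workers (illustrative sketch).
# Running with OMP_NUM_THREADS=1 keeps each BLAS/LAPACK worker on a single core.
import numpy as np
from scipy.linalg import eigh
from concurrent.futures import ProcessPoolExecutor

def solve_one(a):
    return eigh(a)  # eigenvalues and eigenvectors of one symmetric matrix

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y = rng.uniform(size=(16, 1024, 1024)) / 16
    s = y @ y.transpose(0, 2, 1) + np.eye(1024)  # batch of SPD matrices

    with ProcessPoolExecutor(max_workers=16) as pool:
        results = list(pool.map(solve_one, s))  # one problem per process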

Read more comments on GitHub >

