slow linalg.solve
Dear jax team,

I'd like to perform a lot (1e5–1e7) of `np.linalg.solve` operations on small (3×3) matrices. jax is a lot slower than the numpy version: 17 s in jax vs. 382 ms in numpy.

Is 3×3 too small to be efficient on GPU? Or is the `solve` method just not implemented as efficiently for such a task as the numpy version?
```python
import jax
import jax.numpy as np
import numpy as onp

def to_gpu(arr):
    return jax.device_put(arr.astype(np.float32))

num_matrices = int(1e6)
sidelen = 3
A = to_gpu(onp.random.normal(size=(num_matrices, sidelen, sidelen)))
b = to_gpu(onp.random.normal(size=(num_matrices, sidelen)))

solve_jit = jax.jit(np.linalg.solve)
solve_vmap = jax.vmap(np.linalg.solve)
solve_vmap_jit = jax.jit(jax.vmap(np.linalg.solve))

# run jit'ed versions once to exclude compilation time
_ = solve_jit(A, b)
_ = solve_vmap_jit(A, b)

%time np.linalg.solve(A, b).block_until_ready()   # 16.5 s
%time solve_jit(A, b).block_until_ready()         # 17.0 s
%time solve_vmap(A, b).block_until_ready()        # 17.0 s
%time solve_vmap_jit(A, b).block_until_ready()    # 15.3 s
%time onp.linalg.solve(A, b)                      # 383 ms
```
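For reference, the fast NumPy baseline above relies on `np.linalg.solve` broadcasting over leading batch dimensions, so all the small systems are dispatched in a single call. A minimal sketch of that batched behavior (note: NumPy 2.0 changed how a 2-D right-hand side is interpreted, so an explicit trailing axis of size 1 keeps the batch-of-vectors intent unambiguous):

```python
import numpy as onp

# np.linalg.solve broadcasts over leading batch dimensions:
# A has shape (N, 3, 3) and b has shape (N, 3, 1), so every one of
# the N small systems is solved in one vectorized call.
rng = onp.random.default_rng(0)
N = 1000
A = rng.normal(size=(N, 3, 3))
b = rng.normal(size=(N, 3, 1))  # explicit trailing axis: a batch of column vectors

x = onp.linalg.solve(A, b)      # shape (N, 3, 1)

# Sanity check: A @ x reproduces b for every matrix in the batch.
assert onp.allclose(A @ x, b)
```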
Issue Analytics
- Created 4 years ago
- Comments: 6 (6 by maintainers)
I have merged the fix, but it requires that you either rebuild Jaxlib yourself or wait for us to make a new Jaxlib release. (We’re probably just about due for one.)
There should be little difference between the `vmap` and non-`vmap`ed versions of `np.linalg.solve`; ultimately they will end up building very similar computations. You might observe differences in the absence of `jit`, but note that PR #1144 added a `@jit` decorator in the implementation of `np.linalg.solve` since it speeds things up. Hope that helps!
I have a change that adds support for batched LU decompositions and triangular solves on GPU to JAX. I now see on a GTX 1080:
which is presumably more like what you wanted! I’ll clean it up and get it submitted.
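Until a jaxlib release with the batched GPU kernels is available, one workaround for fixed 3×3 systems is to bypass LU entirely and vectorize Cramer's rule over the batch axis. The sketch below is illustrative, not part of JAX (`solve3x3_cramer` and `det3` are hypothetical helpers); it is written in plain NumPy, and the same elementwise operations would carry over to `jax.numpy` with a functional-update style for the column replacement. Keep in mind that Cramer's rule is less numerically robust than LU for ill-conditioned matrices:

```python
import numpy as onp

def solve3x3_cramer(A, b):
    """Solve a batch of 3x3 systems via Cramer's rule.

    A: (N, 3, 3), b: (N, 3) -> x: (N, 3). Every operation is
    elementwise over the batch axis, so nothing loops over N.
    """
    def det3(M):
        # Vectorized 3x3 determinant by cofactor expansion along row 0.
        return (M[:, 0, 0] * (M[:, 1, 1] * M[:, 2, 2] - M[:, 1, 2] * M[:, 2, 1])
                - M[:, 0, 1] * (M[:, 1, 0] * M[:, 2, 2] - M[:, 1, 2] * M[:, 2, 0])
                + M[:, 0, 2] * (M[:, 1, 0] * M[:, 2, 1] - M[:, 1, 1] * M[:, 2, 0]))

    det = det3(A)
    x = onp.empty_like(b)
    for i in range(3):          # loop over the 3 columns, not the batch
        Ai = A.copy()
        Ai[:, :, i] = b         # replace column i with the right-hand side
        x[:, i] = det3(Ai) / det
    return x
```

For well-conditioned random matrices this matches a per-matrix `np.linalg.solve` to floating-point tolerance while issuing only a handful of large elementwise kernels instead of a million tiny LAPACK calls.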