slow linalg.solve

Dear jax team,

I’d like to perform many (1e5-1e7) np.linalg.solve operations on small (3x3) matrices. JAX is much slower than NumPy for this: 17 s in JAX vs. 382 ms in NumPy.

Is 3x3 too small to be effective on GPU? Or is the solve method just not as efficiently implemented for such a task as the numpy version?

import jax
import jax.numpy as np
import numpy as onp

def to_gpu(arr):
    return jax.device_put(arr.astype(np.float32))

num_matrices = int(1e6)
sidelen = 3

A = to_gpu(onp.random.normal(size=(num_matrices, sidelen, sidelen)))
b = to_gpu(onp.random.normal(size=(num_matrices, sidelen)))

solve_jit = jax.jit(np.linalg.solve)
solve_vmap = jax.vmap(np.linalg.solve)
solve_vmap_jit = jax.jit(jax.vmap(np.linalg.solve))

# run jit'ed versions once
_ = solve_jit(A, b)
_ = solve_vmap_jit(A, b)
                                     
%time np.linalg.solve(A, b).block_until_ready()   # 16.5 s
%time solve_jit(A, b).block_until_ready()         # 17.0 s
%time solve_vmap(A, b).block_until_ready()        # 17.0 s
%time solve_vmap_jit(A, b).block_until_ready()    # 15.3 s

%time onp.linalg.solve(A, b)  # 383 ms

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Aug 8, 2019

I have merged the fix, but it requires that you either rebuild Jaxlib yourself or wait for us to make a new Jaxlib release. (We’re probably just about due for one.)

There should be little difference between the vmapped and non-vmapped versions of np.linalg.solve; ultimately they end up building very similar computations. You might observe differences in the absence of jit, but note that PR #1144 added a @jit decorator to the implementation of np.linalg.solve, since it speeds things up.

Hope that helps!
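
As a rough illustration of the point about vmap, the sketch below solves the same stack of 3x3 systems once through the batched call and once through `jax.vmap`, then compares the results. The variable names, the smaller batch size, and the well-conditioned test matrices are assumptions made for this example. The explicit trailing axis on `b` in the batched call is also an assumption, added so the call does not depend on how a given JAX version interprets a 2-D right-hand side.

```python
import numpy as onp
import jax
import jax.numpy as jnp

rng = onp.random.default_rng(0)
n = 1000  # smaller than the 1e6 in the issue, to keep the check quick

# Well-conditioned 3x3 systems (identity plus small noise), an assumption for the demo.
A = jnp.asarray(onp.eye(3) + 0.1 * rng.normal(size=(n, 3, 3)), dtype=jnp.float32)
b = jnp.asarray(rng.normal(size=(n, 3)), dtype=jnp.float32)

# Batched call: treat each right-hand side as a 3x1 column, then drop that axis.
solve_batched = jax.jit(lambda A, b: jnp.linalg.solve(A, b[..., None])[..., 0])

# vmap call: map a single (3, 3) @ x = (3,) solve over the leading axis.
solve_vmapped = jax.jit(jax.vmap(jnp.linalg.solve))

x_batched = solve_batched(A, b).block_until_ready()
x_vmapped = solve_vmapped(A, b).block_until_ready()

# Largest elementwise difference between the two formulations.
max_diff = float(jnp.max(jnp.abs(x_batched - x_vmapped)))
print("max difference:", max_diff)
```

Both variants should agree up to float32 rounding. Timing them with `%time`, as in the issue, is the way to check whether they also perform similarly on a given jaxlib build.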

1 reaction
hawkinsp commented, Aug 8, 2019

I have a change that adds support for batched LU decompositions and triangular solves on GPU to JAX. I now see on a GTX 1080:

In [16]: %time solve_jit(A, b).block_until_ready()         # 17.0 s
CPU times: user 21.4 ms, sys: 3.13 ms, total: 24.5 ms
Wall time: 23.5 ms

which is presumably more like what you wanted! I’ll clean it up and get it submitted.
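
To make "batched" concrete, here is a minimal NumPy sketch of the idea: every elimination step is applied to all matrices at once instead of looping over them one by one. It is not the JAX or GPU implementation from the fix. The function name `batched_lu_solve` is hypothetical, and the routine uses Gaussian elimination with partial pivoting, which computes the LU factors implicitly, applies the forward substitution to `b` on the fly, and finishes with a triangular back-substitution.

```python
import numpy as onp

def batched_lu_solve(A, b):
    """Solve A[i] @ x[i] = b[i] for a stack of small square systems.

    A has shape (batch, n, n) and b has shape (batch, n). Each step below
    is vectorized over the batch axis; the only Python loops run over n.
    """
    A = onp.array(A, dtype=onp.float64)  # work on copies
    b = onp.array(b, dtype=onp.float64)
    batch, n, _ = A.shape
    rows = onp.arange(batch)

    for k in range(n):
        # Partial pivoting: per matrix, pick the row >= k with the largest |entry| in column k.
        pivot = k + onp.argmax(onp.abs(A[:, k:, k]), axis=1)
        # Swap row k with the pivot row in every system (fancy indexing returns copies).
        A[rows, k], A[rows, pivot] = A[rows, pivot], A[rows, k]
        b[rows, k], b[rows, pivot] = b[rows, pivot], b[rows, k]
        # Eliminate column k below the diagonal, updating b alongside.
        for i in range(k + 1, n):
            factor = A[:, i, k] / A[:, k, k]
            A[:, i, k:] -= factor[:, None] * A[:, k, k:]
            b[:, i] -= factor * b[:, k]

    # Back-substitution on the upper-triangular systems.
    x = onp.empty_like(b)
    for i in range(n - 1, -1, -1):
        known = onp.sum(A[:, i, i + 1:] * x[:, i + 1:], axis=1)
        x[:, i] = (b[:, i] - known) / A[:, i, i]
    return x

rng = onp.random.default_rng(0)
A = rng.normal(size=(1000, 3, 3))
b = rng.normal(size=(1000, 3))

x = batched_lu_solve(A, b)
x_ref = onp.linalg.solve(A, b[..., None])[..., 0]  # NumPy reference solution
print(onp.allclose(x, x_ref, rtol=1e-6))
```

For 3x3 systems the Python loops run only a handful of times, so the cost is dominated by array operations over the whole batch, which is the property a batched GPU kernel exploits.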
