jax.numpy.linalg ops missing GPU implementation
I often hit messages like
"Singular value decomposition is only implemented on the CPU backend"
(https://github.com/google/jax/blob/27746b8c73f9ca9928da5da40b7382ae648a5f8d/jax/lax_linalg.py)
for many jax.numpy.linalg ops. Examples I've hit so far are jax.numpy.linalg.svd and jax.numpy.linalg.eigh. It would be nice to be able to run these on GPU.
Thanks!
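Until GPU kernels land, a common workaround (my own sketch, not something suggested in this thread) is to place the operands on the CPU backend explicitly with jax.device_put, so that svd and eigh run where they are implemented:

```python
import jax
import jax.numpy as jnp

# Pin the operand to the CPU backend, where svd/eigh are implemented.
cpu = jax.devices("cpu")[0]
a = jax.device_put(jnp.array([[3.0, 1.0], [1.0, 3.0]]), cpu)

u, s, vt = jnp.linalg.svd(a)   # runs on CPU
w, v = jnp.linalg.eigh(a)      # eigenvalues of [[3,1],[1,3]] are 2 and 4

# Sanity check: reconstruct a from its SVD factors.
print(jnp.allclose(u @ jnp.diag(s) @ vt, a, atol=1e-5))
```

The rest of the computation can stay on GPU; only the array fed to the unsupported op needs to live on CPU, at the cost of a host/device transfer.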
Issue Analytics
- Created 4 years ago
- Comments: 7 (6 by maintainers)
There is now an LU decomposition implementation that works on GPU. However, it may not be the most performant, since it's implemented in JAX itself; we would still do well to add a specialized GPU implementation that calls cuSolver or MAGMA.
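For reference, the LU decomposition is exposed through the SciPy-style wrapper jax.scipy.linalg.lu (a minimal sketch; the exact entry point available depends on your JAX version):

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[4.0, 3.0], [6.0, 3.0]])

# SciPy-style factorization: a == p @ l @ u, with p a permutation matrix,
# l unit lower triangular, and u upper triangular.
p, l, u = lu(a)
print(jnp.allclose(p @ l @ u, a, atol=1e-5))
```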
As it happens, jax.numpy.linalg.inv should run on GPU right now! (It might not be terribly fast, since it’s using a QR decomposition instead of an LU decomposition, and one that isn’t necessarily that well tuned.)
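A quick check that inv works end to end (a sketch; the QR-based path described above is an internal detail and not visible in the API):

```python
import jax.numpy as jnp

a = jnp.array([[4.0, 7.0], [2.0, 6.0]])
a_inv = jnp.linalg.inv(a)

# a @ a_inv should be (numerically) the identity.
print(jnp.allclose(a @ a_inv, jnp.eye(2), atol=1e-5))
```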