jax.numpy.linalg ops missing GPU implementation

I often hit messages like "Singular value decomposition is only implemented on the CPU backend" (https://github.com/google/jax/blob/27746b8c73f9ca9928da5da40b7382ae648a5f8d/jax/lax_linalg.py) for many jax.numpy.linalg ops. So far I have hit this when calling jax.numpy.linalg.svd and jax.numpy.linalg.eigh.

It would be nice to be able to run on GPU.

Thanks!
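
For reference, the two calls named above can be exercised with a minimal sketch like the one below. The matrix shape and the random seed are arbitrary assumptions, not taken from the issue. On a JAX build where an op has no GPU kernel, running this on a GPU backend would be expected to raise an error like the one quoted above; on a CPU backend it should run as written.

```python
import jax
import jax.numpy as jnp

# Hypothetical test data: a small random matrix (shape and seed are arbitrary).
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (4, 3))

# Singular value decomposition of a rectangular matrix (reduced form).
u, s, vt = jnp.linalg.svd(a, full_matrices=False)

# Eigendecomposition of a symmetric matrix built from the same data.
sym = a.T @ a
w, v = jnp.linalg.eigh(sym)

# Report which backend JAX picked; the value depends on the installation.
print(jax.default_backend())
```

Checking `jax.default_backend()` first is a quick way to tell whether a failure comes from the GPU backend or from something else.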

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 7 (6 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Jun 28, 2019

There is now an LU decomposition implementation that works on GPU. However, it may not be the most performant (it’s implemented in JAX itself). We still would do well to add a specialized GPU implementation that calls cuSolver or MAGMA.
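
A minimal sketch of an LU decomposition call in JAX follows. It uses `jax.scipy.linalg.lu`, which exists in current JAX releases; whether that entry point maps onto the implementation described in this comment is an assumption, and the 2x2 matrix is made up for illustration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

# Hypothetical input: a small, well-conditioned square matrix.
a = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])

# Returns a permutation matrix p, unit lower-triangular l, and upper-triangular u
# such that a == p @ l @ u (up to floating-point error).
p, l, u = lu(a)

# Rebuild the input from its factors as a sanity check.
recon = p @ l @ u
print(jnp.max(jnp.abs(recon - a)))
```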

1 reaction
hawkinsp commented, Jun 18, 2019

As it happens, jax.numpy.linalg.inv should run on GPU right now! (It might not be terribly fast, since it’s using a QR decomposition instead of an LU decomposition, and one that isn’t necessarily that well tuned.)
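
The `jax.numpy.linalg.inv` call mentioned here can be tried with a short sketch like this one. The input matrix is a made-up example, and nothing below depends on which decomposition JAX uses internally.

```python
import jax.numpy as jnp

# Hypothetical input: a small invertible matrix.
a = jnp.array([[2.0, 1.0],
               [1.0, 3.0]])

a_inv = jnp.linalg.inv(a)

# Largest deviation of a @ a_inv from the identity; small if the inverse is accurate.
residual = jnp.max(jnp.abs(a @ a_inv - jnp.eye(2)))
print(residual)
```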
