Jax support for sparse matrices

See original GitHub issue

Hi all,

I was discussing adding support for sparse matrices in JAX with @jakevdp in #6466. He has already done quite a bit of groundwork, but there are a lot of open questions about the what and the how. Quoting Jake here:

  • 2D CSC/CSR/COO may not be the abstraction we want in the end… pydata-sparse has N-dimensional sparse arrays implemented using COO or a multi-dimensional generalization of CSR. Similarly taco has routines for general N-dimensional sparse computation. That route may be better, particularly for some deep learning workflows that e.g. have 4D batched arrays of sparse image sequences.
  • XLA scatter/gather implementations only go so far… any operations we implement should have efficient lowerings on as many backends as possible - I’ve started with cusparse on GPU because the low-level routines are already available in our GPU build process. This means that every non-trivial operation should be either implemented as a primitive or composed of other primitives.
  • JAX is about composability, and in particular pmap/vmap, jit, and grad are central to the API. I want to spend some time exploring how we can best implement gradients of sparse operations or batched sparse operations, and also think about how ongoing work on dynamic shapes can be utilized to make jit of sparse operations more efficient and usable. With all that up in the air I want to keep the design minimal for now in case we need to re-think our approach based on what we find.
  • all of the above should be motivated by real-world workflows that people have. I’m starting to gather a few example workflows so we can have that in mind.
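For a concrete picture of the abstraction Jake's first point refers to, here is a minimal sketch (plain NumPy, with made-up example data) of the layout an N-dimensional COO array typically uses, and roughly what pydata-sparse wraps: an (ndim, nnz) coordinate array plus an (nnz,) values array.

```python
import numpy as np

# A 3-D array with only 4 nonzero entries, stored in COO form:
# an (ndim, nnz) array of coordinates plus an (nnz,) array of values.
shape = (2, 3, 4)
coords = np.array([[0, 0, 1, 1],    # axis-0 indices
                   [0, 2, 1, 2],    # axis-1 indices
                   [1, 3, 0, 2]])   # axis-2 indices
data = np.array([10., 20., 30., 40.])

# pydata-sparse wraps the same layout: sparse.COO(coords, data, shape=shape)

# Densify by scattering the stored values into a zero array.
dense = np.zeros(shape)
dense[tuple(coords)] = data
```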

So I set out to compare the implementations in PyTorch and TensorFlow to get a feel for what’s out there, keeping the points above in mind. Specifically, we’re interested in how they implement it at a low level: do they secretly turn everything dense under the surface, or do sparse tensors get the full treatment with dedicated ops?

I’m summarizing what I found in the table below; TL;DR: PyTorch offers better support. Since documentation and support are all over the place, this notebook runs a bunch of common operations on sparse tensors to see which are supported. It’s been a while since I last used TensorFlow, so let me know if I made a mistake!
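To give a flavor of the kind of checks behind the table (the notebook itself covers many more operations), here is an illustrative sketch, not taken from the notebook, of probing one operation in each framework:

```python
import torch
import tensorflow as tf

# PyTorch: COO tensor; sparse @ dense matmul is supported.
i = torch.tensor([[0, 1, 1], [2, 0, 2]])        # (ndim, nnz) coordinates
v = torch.tensor([3., 4., 5.])
sp_t = torch.sparse_coo_tensor(i, v, (2, 3))
dense_t = torch.randn(3, 2)
out_t = torch.sparse.mm(sp_t, dense_t)          # works

# TensorFlow: SparseTensor; sparse @ dense matmul is supported too.
sp_tf = tf.sparse.SparseTensor(indices=[[0, 2], [1, 0], [1, 2]],
                               values=[3., 4., 5.],
                               dense_shape=[2, 3])
dense_tf = tf.random.normal((3, 2))
out_tf = tf.sparse.sparse_dense_matmul(sp_tf, dense_tf)  # works
```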

| Feature | PyTorch | TensorFlow |
| --- | --- | --- |
| Documentation | here | here |
| Formats | COO | COO |
| Supported ops | Variations of matmul | Sums, min, matmul, element-wise ops through tf.sparse.map_values |
| Specialized ops | Probably[^5] | Seems so[^4] |
| Grad of sparse[^1] | Limited to specific matmuls | No |
| Sparse grad[^2] | Yes | No |
| Dimensions | 2D, hybrid | 2D (the tensor supports higher dimensions, but the operations do not) |
| Uncoalesced[^3] allowed | Yes | No mention |
| Sparse × sparse matmul | No | No |
| Extra goodies | An Adam optimizer for sparse tensors | None |

[^1]: Meaning the gradient with respect to the sparse tensor.
[^2]: Meaning calculating the gradient without converting it to a dense tensor.
[^3]: Uncoalesced means duplicate entries at the same coordinates; the value of the element is then the sum of the duplicates.
[^4]: Sparse ops are imported from gen_sparse_ops.py, which doesn’t exist in the repo?
[^5]: They mention torch.cuda.sparse is now defunct here, but I can’t find any other mention of it in the source code.
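As a tiny illustration of footnote 3, an uncoalesced COO tensor may list the same coordinate more than once; the logical value at that position is the sum of the duplicates. A sketch in plain NumPy (example data made up):

```python
import numpy as np

coords = np.array([[0, 0, 1],     # row indices (position (0, 1) appears twice)
                   [1, 1, 2]])    # column indices
data = np.array([2., 5., 7.])

# Coalescing / densifying sums the duplicate entries: dense[0, 1] == 7.
dense = np.zeros((2, 3))
np.add.at(dense, tuple(coords), data)
```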

To do

  • Dig into the source code a little further (although if someone who is actually familiar with the code could chime in, that’d be great).
  • Check how pydata-sparse implements N-D sparse arrays.

Proposal:

  • Get a very minimal differentiable sparse matmul working for the COO format, using XLA ops plus one GPU-specific backend via cusparse (some help and guidance here would be nice); a rough sketch of the XLA-ops path is below.
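A minimal sketch of the XLA-ops path, assuming a COO matrix-times-dense-vector product built only from gather, multiply, and segment_sum; because it is composed of existing JAX primitives, gradients with respect to both the stored values and the dense operand fall out of autodiff for free (the function and variable names here are hypothetical, not an existing JAX API):

```python
import jax
import jax.numpy as jnp

def coo_matvec(data, row, col, v, *, n_rows):
    """y = A @ v for A given in COO form (data[k] = A[row[k], col[k]])."""
    # Gather the needed entries of v, multiply by the stored values,
    # and scatter-add each product back into its row of the output.
    return jax.ops.segment_sum(data * v[col], row, num_segments=n_rows)

# A = [[0., 2., 0.],
#      [0., 0., 3.]]
data = jnp.array([2., 3.])
row = jnp.array([0, 1])
col = jnp.array([1, 2])
v = jnp.array([1., 1., 1.])

y = coo_matvec(data, row, col, v, n_rows=2)   # [2., 3.]; jit works with n_rows static
# Gradient w.r.t. the stored values via ordinary autodiff.
g = jax.grad(lambda d: coo_matvec(d, row, col, v, n_rows=2).sum())(data)
```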

I figured we can use this issue to track progress and discuss, so … thoughts?

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 7
  • Comments: 19 (7 by maintainers)

Top GitHub Comments

1 reaction
jakevdp commented, Sep 29, 2021

I don’t have any concrete plans for that, but it is definitely a possibility. In particular, cuSolver could be a compelling project because jaxlib already depends on CUDA, so it wouldn’t require any new dependencies.

One related TODO is to lower BCOO routines to relevant cuSparse calls on the GPU; I haven’t done this yet because I’m not entirely sure how to handle batched calls on the GPU.
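For context, the BCOO routines mentioned here live in jax.experimental.sparse; a minimal usage sketch, keeping in mind the module is experimental and the API may change:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Convert a mostly-zero dense matrix to the batched-COO (BCOO) format.
M = jnp.array([[0., 1., 0.],
               [2., 0., 3.]])
M_sp = sparse.BCOO.fromdense(M)

v = jnp.array([1., 2., 3.])

# sparsify() traces a function written against dense arrays and substitutes
# sparse rules for the primitives that have them.
@sparse.sparsify
def matvec(A, x):
    return A @ x

print(matvec(M_sp, v))   # same result as M @ v
```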

0 reactions
jakevdp commented, Oct 18, 2021

I’m going to close this in favor of #765 to avoid duplication of discussion.
