JAX support for sparse matrices
Hi all,
I was discussing with @jakevdp in #6466 about adding support for sparse matrices in JAX. He has already done quite a bit of groundwork, but there are a lot of open questions on the what and how. Quoting Jake here:
- 2D CSC/CSR/COO may not be the abstraction we want in the end… pydata-sparse has N-dimensional sparse arrays implemented using COO or a multi-dimensional generalization of CSR. Similarly taco has routines for general N-dimensional sparse computation. That route may be better, particularly for some deep learning workflows that e.g. have 4D batched arrays of sparse image sequences.
- XLA scatter/gather implementations only go so far… any operations we implement should have efficient lowerings on as many backends as possible - I’ve started with cusparse on GPU because the low-level routines are already available in our GPU build process. This means that every non-trivial operation should be either implemented as a primitive or composed of other primitives
- JAX is about composability, and in particular pmap/vmap, jit, and grad are central to the API. I want to spend some time exploring how we can best implement gradients of sparse operations or batched sparse operations, and also think about how ongoing work on dynamic shapes can be utilized to make jit of sparse operations more efficient and usable. With all that up in the air I want to keep the design minimal for now in case we need to re-think our approach based on what we find
- all of the above should be motivated by real-world workflows that people have. I’m starting to gather a few example workflows so we can have that in mind.
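To make the N-dimensional COO layout from the first point above concrete: it is essentially an (ndim, nnz) integer index array plus an (nnz,) data array. A minimal NumPy sketch (the function name and layout are illustrative, not pydata-sparse's actual internals):

```python
import numpy as np

def ndcoo_todense(indices, data, shape):
    """Scatter an N-D COO array (indices: (ndim, nnz), data: (nnz,)) back to dense."""
    dense = np.zeros(shape, dtype=data.dtype)
    # np.add.at accumulates duplicate coordinates, i.e. "uncoalesced" COO semantics.
    np.add.at(dense, tuple(indices), data)
    return dense

# A sparse 2x3x4 array with three stored elements, two of them at the same coordinate.
indices = np.array([[0, 1, 1],
                    [2, 0, 0],
                    [3, 1, 1]])              # shape (ndim=3, nnz=3)
data = np.array([1.0, 2.0, 0.5])
print(ndcoo_todense(indices, data, (2, 3, 4))[1, 0, 1])   # -> 2.5
```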
So I set out to compare the implementations in PyTorch and TensorFlow to get a feel for what's out there, with the points above in mind. Specifically, we're interested in how they implement it at a low level: do they secretly turn everything dense under the surface, or do sparse tensors get the full treatment with dedicated ops?
I'm summarizing what I found in the table below - TL;DR: PyTorch offers better support. Since documentation and support are all over the place, in this notebook you can find a bunch of common operations on sparse tensors to see whether they are supported. It's been a while since I last used TensorFlow, so let me know if I made a mistake!
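For concreteness, this is roughly the kind of thing the notebook compares: the same sparse-dense matmul spelled out in each framework (these are the public APIs as I understand them; double-check against the current docs):

```python
import numpy as np

# Same 2x3 sparse matrix in both frameworks: 3.0 at (0,2), 4.0 at (1,0), 5.0 at (1,2).
indices = np.array([[0, 1, 1], [2, 0, 2]], dtype=np.int64)   # (row, col), shape (2, nnz)
values = np.array([3.0, 4.0, 5.0], dtype=np.float32)
dense = np.ones((3, 2), dtype=np.float32)

# PyTorch: COO tensor + sparse @ dense matmul
import torch
a_pt = torch.sparse_coo_tensor(torch.from_numpy(indices), torch.from_numpy(values), (2, 3))
out_pt = torch.sparse.mm(a_pt, torch.from_numpy(dense))

# TensorFlow: SparseTensor wants indices as (nnz, ndim)
import tensorflow as tf
a_tf = tf.sparse.SparseTensor(indices=indices.T, values=values, dense_shape=(2, 3))
out_tf = tf.sparse.sparse_dense_matmul(a_tf, tf.constant(dense))
```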
Feature | PyTorch | TensorFlow
---|---|---
Documentation | here | here
Formats | COO | COO
Supported ops | variations of matmul | sums, min, matmul, element-wise ops through tf.sparse.map_values
Specialized ops | Probably[^5] | Seems so[^4]
Grad of sparse[^1] | Limited to specific matmul | No
Sparse grad[^2] | Yes | No
Dimensions | 2D, plus hybrid sparse/dense tensors | 2D - the tensor supports higher dimensions, but the operations do not
Uncoalesced[^3] allowed | Yes | No mention
Sparse-sparse matmul | No | No
Extra goodies | Has an Adam optimizer for sparse tensors | Nope
[^1]: Meaning the gradient with respect to the sparse tensor.
[^2]: Meaning calculating the gradient without turning it into a dense tensor.
[^3]: Uncoalesced means duplicate entries in the coordinates; the total value of an element is then the sum of its duplicate entries.
[^4]: Sparse ops are imported from gen_sparse_ops.py, which doesn't exist in the repo?
[^5]: They mention Torchcudasparse is now defunct here, but I can't find any other mention in the source code.
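As a small illustration of footnote [^3]: coalescing just means "merge duplicate coordinates by summing their values". A NumPy sketch (the helper name is made up):

```python
import numpy as np

def coalesce(indices, values, shape):
    """Merge duplicate COO coordinates (indices: (ndim, nnz)) by summing their values."""
    flat = np.ravel_multi_index(tuple(indices), shape)        # linearize coordinates
    uniq, inverse = np.unique(flat, return_inverse=True)      # one slot per unique coordinate
    summed = np.zeros(len(uniq), dtype=values.dtype)
    np.add.at(summed, inverse, values)                        # accumulate duplicates
    return np.stack(np.unravel_index(uniq, shape)), summed

idx = np.array([[0, 1, 1, 1], [2, 0, 2, 0]])   # coordinate (1, 0) appears twice
val = np.array([1.0, 2.0, 3.0, 4.0])
print(coalesce(idx, val, (2, 3)))               # (1, 0) now stores 2.0 + 4.0 = 6.0
```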
To do
- Dig into the source code a little further (although if someone who is actually familiar with the code could chime in, that'd be great).
- Check how pydata-sparse implements N-D sparse arrays.
Proposal:
- Get a very minimal differentiable sparse matmul working for the COO format using XLA ops, plus one GPU-specific backend via cuSPARSE (some help and guidance here would be nice) - a rough sketch of what this could look like is below.
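Here is that sketch, purely in terms of XLA-friendly gather/scatter ops that JAX already exposes (the function and its signature are made up for illustration, not an existing JAX API):

```python
import jax
import jax.numpy as jnp

def coo_matmul(data, row, col, dense, *, n_rows):
    """(n_rows x K sparse, COO format) @ (K x N dense), as gather + scatter-add."""
    prods = data[:, None] * dense[col]                     # gather rows of `dense`, weight by stored values
    out = jnp.zeros((n_rows, dense.shape[1]), dense.dtype)
    return out.at[row].add(prods)                          # scatter-add into output rows

data = jnp.array([3.0, 4.0, 5.0])
row = jnp.array([0, 1, 1])
col = jnp.array([2, 0, 2])
dense = jnp.ones((3, 2))

y = jax.jit(coo_matmul, static_argnames="n_rows")(data, row, col, dense, n_rows=2)
# Autodiff already works through the gather/scatter, giving a gradient w.r.t. the stored values:
g = jax.grad(lambda d: coo_matmul(d, row, col, dense, n_rows=2).sum())(data)
```

Gradients with respect to the stored values fall out of autodiff here; the open questions are how to package this behind a proper sparse type and how to lower it to cuSPARSE on GPU.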
I figured we can use this issue to track progress and discuss, so … thoughts?
Top GitHub Comments
I don’t have any concrete plans for that, but it is definitely a possibility. In particular, cuSolver could be a compelling project because jaxlib already depends on CUDA, so it wouldn’t require any new dependencies.
One related TODO is to lower BCOO routines to relevant cuSparse calls on the GPU; I haven’t done this yet because I’m not entirely sure how to handle batched calls on the GPU.
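For context, the BCOO routines mentioned above live in jax.experimental.sparse; a minimal usage sketch, assuming a recent jax/jaxlib where this experimental API is available (it may still change):

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = sparse.BCOO.fromdense(jnp.array([[0., 1., 0.],
                                     [2., 0., 3.]]))
v = jnp.array([1., 2., 3.])
print(M @ v)                                    # sparse @ dense matvec

# sparsify lets dense-style code accept BCOO arguments, and grads flow through it:
f = sparse.sparsify(lambda a, x: (a @ x).sum())
print(jax.grad(f, argnums=1)(M, v))

# Batched sparse matrices carry leading batch dimensions (n_batch) -
# exactly the case a cuSPARSE lowering would need to handle.
Mb = sparse.BCOO.fromdense(jnp.ones((4, 2, 3)), n_batch=1)
```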
I’m going to close this in favor of #765 to avoid duplication of discussion.