Pseudoinverse solves
See the original GitHub issue: https://github.com/google/jax/pull/2794
If we’re willing to define a “pseudo-inverse solve” (which, as far as I can tell, does not exist in NumPy or SciPy) that computes A⁺ b rather than A⁺ directly, we can potentially speed up gradients of pseudo-inverse solves by a large factor by defining a custom gradient rule, using tricks similar to those in `lax.custom_linear_solve`.
This will be most relevant for computation on CPUs (where the cost of matrix multiplication is comparable to that of computing an SVD/eigendecomposition) and when there is only a single right-hand-side vector.
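As a rough illustration of the idea (a sketch, not code from this thread): the helper name `pinv_solve`, the cutoff choice, and the decision to treat the matrix as a constant under differentiation are all assumptions made here.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def pinv_solve(a, b):
    # Apply pinv(a) to a single right-hand-side vector b without ever
    # materializing pinv(a): only matrix-vector products against the
    # SVD factors.
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)
    cutoff = jnp.finfo(s.dtype).eps * max(a.shape) * s[0]
    s_inv = jnp.where(s > cutoff, 1.0 / s, 0.0)
    return vh.T @ (s_inv * (u.T @ b))

def _pinv_solve_fwd(a, b):
    x = pinv_solve(a, b)
    return x, (a,)

def _pinv_solve_bwd(res, x_bar):
    (a,) = res
    # d(pinv(A) b)/db = pinv(A), so the cotangent for b is
    # pinv(A)^T x_bar = pinv(A^T) x_bar: another pseudo-inverse solve,
    # but with A^T, in the spirit of lax.custom_linear_solve.
    b_bar = pinv_solve(a.T, x_bar)
    # This sketch treats A as a constant and zeroes its cotangent; a
    # real rule would need the remaining terms of the pinv JVP
    # contracted against b.
    return jnp.zeros_like(a), b_bar

pinv_solve.defvjp(_pinv_solve_fwd, _pinv_solve_bwd)
```

With something like this, `jax.grad` of a scalar loss with respect to `b` only ever runs SVDs and matrix-vector products, rather than the dense matrix-matrix products inside the `pinv` JVP rule.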
Should we go ahead and add a helper function for this somewhere? Maybe `jax.ops.linalg`? If the performance gap is large enough, we could also add a loud warning to the docstring for `jnp.linalg.pinv`.
I suspect it simply doesn’t exist in NumPy/SciPy because there isn’t much to be gained from such a function if you only care about the forward solve.
EDIT NOTE: removed incorrect benchmark that only worked for invertible matrices.
By the way, should the solves just form `U, s, V = svd(a)` and then multiply by `U @ (np.where(s > cutoff, np.divide(1, s), 0.) * V)` and `V.T @ (np.where(s > cutoff, np.divide(1, s), 0.) * U.T)`, respectively?

Yes, that intuition looks about right to me, at least for inverses – you turn matrix-matrix multiplications into matrix-vector multiplications.
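In code, that suggestion (under the NumPy convention that `svd` returns `u, s, vh` with `vh = Vᵀ`) amounts to reusing a single SVD for both the forward solve and the transposed solve needed in the gradient; a minimal sketch with made-up helper names:

```python
import jax.numpy as jnp

def _apply_pinv(u, s, vh, cutoff, b):
    # x = pinv(A) @ b = V @ diag(1/s) @ U^T @ b, as matrix-vector products.
    s_inv = jnp.where(s > cutoff, 1.0 / s, 0.0)
    return vh.T @ (s_inv * (u.T @ b))

def _apply_pinv_transpose(u, s, vh, cutoff, y):
    # z = pinv(A)^T @ y = U @ diag(1/s) @ V^T @ y, reusing the same SVD.
    s_inv = jnp.where(s > cutoff, 1.0 / s, 0.0)
    return u @ (s_inv * (vh @ y))
```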
The JVP rules for `svd` and `pinv` involve lots of dense matrix-matrix multiplication, which is slow (on CPU). When I profile things on a GPU, `sgemm` operations only take ~25% of the runtime (the rest is inside the SVD), so the potential speed-up there is much smaller. This makes sense because GPUs are faster at matmuls but not really faster at matrix decompositions.
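For what it’s worth, a profiling setup roughly like the following (shapes and trace directory are invented for illustration) is one way to see how the runtime splits between `sgemm` kernels and the SVD itself, by viewing the resulting trace in TensorBoard:

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))
b = jax.random.normal(jax.random.PRNGKey(1), (1000,))

# Differentiating through jnp.linalg.pinv exercises its JVP rule.
grad_fn = jax.jit(jax.grad(lambda a: jnp.sum(jnp.linalg.pinv(a) @ b)))
grad_fn(a).block_until_ready()  # compile outside the trace

with jax.profiler.trace("/tmp/jax-trace"):
    grad_fn(a).block_until_ready()
```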