Differentiable matrix-free linear algebra, optimization and equation solving
This is a meta-issue for keeping track of progress on implementing differentiable higher-order functions from SciPy in JAX, e.g.,
- `scipy.sparse.linalg.gmres` and `cg`: matrix-free linear solves
- `scipy.sparse.linalg.eigs` and `eigsh`: matrix-free eigenvalue problems
- `scipy.optimize.root`: nonlinear equation solving
- `scipy.optimize.fixed_point`: solving for fixed points
- `scipy.integrate.odeint`: solving ordinary differential equations
- `scipy.optimize.minimize`: nonlinear minimization
These higher-order functions are important for implementing sophisticated differentiable programs, both for scientific applications and for machine learning.
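To make that concrete, here is a minimal sketch of what a differentiable fixed-point solver could look like. The names (`fixed_point`, `rev_iter`) are hypothetical rather than a proposed API; the forward pass uses `lax.while_loop`, and the custom VJP differentiates the fixed-point condition implicitly instead of backpropagating through the iterations:

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
    # Forward pass: iterate x <- f(a, x) to convergence (scalar x for simplicity).
    def cond_fun(carry):
        x_prev, x = carry
        return jnp.abs(x_prev - x) > 1e-6
    def body_fun(carry):
        _, x = carry
        return x, f(a, x)
    _, x_star = lax.while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
    return x_star

def fixed_point_fwd(f, a, x_guess):
    x_star = fixed_point(f, a, x_guess)
    return x_star, (a, x_star)

def rev_iter(f, packed, u):
    # One step of the adjoint fixed point: u <- x_star_bar + (df/dx)^T u.
    a, x_star, x_star_bar = packed
    _, vjp_x = jax.vjp(lambda x: f(a, x), x_star)
    return x_star_bar + vjp_x(u)[0]

def fixed_point_rev(f, res, x_star_bar):
    a, x_star = res
    # Solve the adjoint fixed point, then propagate through df/da.
    u = fixed_point(partial(rev_iter, f), (a, x_star, x_star_bar), x_star_bar)
    _, vjp_a = jax.vjp(lambda a: f(a, x_star), a)
    a_bar, = vjp_a(u)
    return a_bar, jnp.zeros_like(x_star)

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)

# Example: sqrt(a) as the fixed point of x <- 0.5 * (x + a / x)
f = lambda a, x: 0.5 * (x + a / x)
print(fixed_point(f, 2.0, 1.0))                          # ~1.41421
print(jax.grad(lambda a: fixed_point(f, a, 1.0))(2.0))   # ~1 / (2 * sqrt(2))
```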
Implementations should build on JAX's custom transformation capabilities. For example, `scipy.optimize.root` should use autodiff to compute the Jacobians or Jacobian-vector products needed for Newton's method.
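As a rough illustration of that idea, here is a minimal sketch (the `newton_root` helper is hypothetical, not an existing JAX or SciPy API) in which forward-mode autodiff supplies the Jacobian for each Newton step:

```python
import jax
import jax.numpy as jnp

def newton_root(f, x0, num_steps=20):
    # Fixed number of Newton steps; a real implementation would add a
    # convergence test (e.g. via lax.while_loop) and a custom derivative rule.
    def body(x, _):
        J = jax.jacfwd(f)(x)                         # Jacobian via forward-mode autodiff
        return x - jnp.linalg.solve(J, f(x)), None   # Newton update
    x, _ = jax.lax.scan(body, x0, None, length=num_steps)
    return x

# Example: solve x**3 - 2 = 0
f = lambda x: x**3 - 2.0
print(newton_root(f, jnp.array([1.0])))              # approximately 2 ** (1/3)
```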
In most cases, I think the right way to do this involves two separate steps, which could happen in parallel:
- Higher-order primitives for defining automatic differentiation rules, not specialized to any particular algorithm, e.g., `lax.custom_linear_solve` from https://github.com/google/jax/pull/1402.
- Implementations of particular algorithms for the forward problems, e.g., a conjugate gradient method for linear solves. These could either be written from scratch using JAX's functional control flow (e.g., `while_loop`) or leverage existing external implementations on particular backends. Either way they will almost certainly need custom derivative rules, rather than differentiating through the forward algorithm; see the sketch after this list for how the two pieces fit together.
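Here is a minimal sketch of how the two pieces could combine, using the `lax.custom_linear_solve` primitive from #1402. The forward solver is stubbed out with a dense solve; a real matrix-free implementation would use something like CG built on `while_loop`:

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve(A, b):
    matvec = lambda x: A @ x
    # Stand-in forward solver; an actual implementation would be an iterative,
    # matrix-free method (CG/GMRES) written with lax.while_loop.
    forward_solve = lambda mv, rhs: jnp.linalg.solve(A, rhs)
    # custom_linear_solve supplies the differentiation rule by solving
    # additional (transposed) linear systems rather than unrolling the solver;
    # symmetric=True reuses the same solver for the transposed system.
    return lax.custom_linear_solve(matvec, b, forward_solve, symmetric=True)

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])   # symmetric positive definite
b = jnp.array([1.0, 2.0])
x = solve(A, b)
# Gradients with respect to b flow through extra linear solves, not iterations:
g = jax.grad(lambda b: solve(A, b).sum())(b)
```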
There’s lots of work to be done here, so please comment if you’re interested in using or implementing any of these.
@romanodev this is really awesome work.
@shoyer thanks for sharing! I think it would be nice to combine your implementation with the sparse matrix-vector dot product from #3717. The jit/GPU implementation still can't beat SciPy, and I suspect this is due to the COO representation of the sparse matrix (SciPy uses CSR: https://github.com/scipy/scipy/blob/v1.5.1/scipy/sparse/base.py#L532). I will do some testing in this direction first.
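For reference, a rough sketch of the COO matrix-vector product in question (the `coo_matvec` name is hypothetical, and this is not the code from #3717): each nonzero contributes `data[k] * v[col[k]]` to output row `row[k]` via a scatter-add, which may be part of why it trails SciPy's CSR implementation with its contiguous row segments:

```python
import jax.numpy as jnp

def coo_matvec(data, row, col, v, num_rows):
    # Scatter-add each nonzero's contribution into its destination row.
    return jnp.zeros(num_rows, dtype=v.dtype).at[row].add(data * v[col])
```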