Best way to define gradient for linear black box function
Dear jax team,

I'd like to use a black box function in jax with `grad`, where the function is a linear operator, i.e. a function `f(x)` executing `A.dot(x)`, where `A` is a matrix too large for memory, so `A` is computed on-the-fly.

I guess I could define its gradient with `jax.custom_gradient`, but I assumed there are performance benefits to telling jax that the operation is linear? Either way, I would have to provide another black box function computing `A.T.dot(x)`.

Is there a way in jax to do this? Or should I just provide the gradient rule via `jax.custom_gradient`?
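For concreteness, here is a minimal sketch of the custom-gradient route I have in mind (written against `jax.custom_vjp`, the current public API for custom reverse-mode rules; `matvec` and `rmatvec` are hypothetical stand-ins for the black-box routines, backed here by an explicit matrix purely for illustration):

```python
import jax
import jax.numpy as jnp

# Stand-ins for the black-box routines; in the real setting A is never
# materialized and these would call out to the on-the-fly implementation.
A = jnp.arange(12.0).reshape(3, 4)

def matvec(x):   # black-box A.dot(x)
    return A @ x

def rmatvec(y):  # black-box A.T.dot(y)
    return A.T @ y

@jax.custom_vjp
def f(x):
    return matvec(x)

def f_fwd(x):
    return f(x), None            # linear, so no residuals to save

def f_bwd(_, ct):
    return (rmatvec(ct),)        # the VJP of a linear map is its transpose

f.defvjp(f_fwd, f_bwd)

g = jax.grad(lambda x: jnp.sum(f(x)))(jnp.ones(4))  # equals A.T @ ones(3)
```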
The idea is to make everything work based on black-box functions for computing matrix-vector products. (We don't have support for explicit sparse matrices in JAX yet.)

If you're willing to define your function as a `core.Primitive`, you can set it up for AD using `ad.deflinear(primitive, transpose_rule)`, where `transpose_rule` is a function that behaves like `A.T.dot(x)`. This notebook may be helpful for elucidating the current (internal/unstable) API surface for primitives.
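For illustration, a minimal sketch of that setup, assuming the internal (unstable) APIs live at `jax.core` and `jax.interpreters.ad` (they may move between JAX versions), with hypothetical `matvec`/`rmatvec` stand-ins for the black-box routines:

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

# Stand-ins for the black-box routines; in the real setting A is never
# materialized and these would call the on-the-fly implementation.
A = jnp.arange(12.0).reshape(3, 4)
matvec = lambda x: A @ x       # computes A.dot(x)
rmatvec = lambda y: A.T @ y    # computes A.T.dot(y)

matvec_p = core.Primitive('matvec')
matvec_p.def_impl(matvec)
matvec_p.def_abstract_eval(
    lambda x: core.ShapedArray((A.shape[0],), x.dtype))

# deflinear registers the JVP (the primitive itself, by linearity) and
# the transpose rule; the rule returns one cotangent per primal input,
# which is why the result is wrapped in a list.
ad.deflinear(matvec_p, lambda ct: [rmatvec(ct)])

def f(x):
    return matvec_p.bind(x)

g = jax.grad(lambda x: jnp.sum(f(x)))(jnp.ones(4))  # equals A.T @ ones(3)
```

Because `deflinear` reuses the primitive itself as its own JVP, the forward pass is never re-linearized, which is where the performance benefit over a generic custom-gradient rule comes from.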