Disable gradients with special type
I have difficulties with structures that have non-differentiable parts. This causes a lot of issues in my current design: e.g. I want a differentiable structure that has parts that are discrete.

I decided to outsmart JAX and make an object that behaves like a scalar (an array) and pass it to the initialiser of `DifferentiableView`. But this gives me a different issue, related to gradient tracking and the types that are allowed to be traced.
```python
import typing

import jax
import jax.numpy as jnp

class NonDifferentiable(int):
    pass

class DifferentiableView(typing.NamedTuple):
    parameter: jnp.ndarray
    num_data: int

def func(s):
    return jnp.exp(s.parameter) ** 2 / s.num_data

s = DifferentiableView(2.0, NonDifferentiable(100))
grad_func = jax.grad(func)(s)
```
This gives an error:

```
TypeError: <class '__main__.NonDifferentiable'> is not a valid Jax type
```
I think it would be pretty cool to have such a recognisable class for labelling (wrapping) scalars and shaped arrays. For JAX it would be a signal to stop tracing gradients for that variable.
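For reference, one way to get this behaviour in current JAX is to register the container as a custom pytree and carry the non-differentiable field as auxiliary data, so that `grad` never sees it as a leaf. A minimal sketch (the `NonTrainableView` name is illustrative, not part of the proposal):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class NonTrainableView:
    # Hypothetical container: `parameter` is a differentiable leaf,
    # `num_data` is carried as static auxiliary data.
    def __init__(self, parameter, num_data):
        self.parameter = parameter
        self.num_data = num_data

def _flatten(view):
    # (children, aux_data): only `parameter` becomes a traced leaf.
    return (view.parameter,), view.num_data

def _unflatten(num_data, children):
    (parameter,) = children
    return NonTrainableView(parameter, num_data)

tree_util.register_pytree_node(NonTrainableView, _flatten, _unflatten)

def func(s):
    return jnp.exp(s.parameter) ** 2 / s.num_data

s = NonTrainableView(jnp.asarray(2.0), 100)
grads = jax.grad(func)(s)   # differentiates w.r.t. `parameter` only
print(grads.parameter)      # 2 * exp(2.0) ** 2 / 100
```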
I disagree. The very first question when I started looking at JAX and recommending it to others was: “How do I stop gradient propagation without changing the code?”. This is a UX question, not a design issue on the user’s side (see https://twitter.com/srush_nlp/status/1260583364102434817?s=20).
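For context, what JAX offers today is `jax.lax.stop_gradient`, which has to be written into the objective itself, i.e. it requires exactly the kind of code change the question is about. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def func(params):
    parameter, num_data = params
    # stop_gradient zeroes the gradient through num_data, but it must be
    # placed inside the objective -- the very code change in question.
    return jnp.exp(parameter) ** 2 / jax.lax.stop_gradient(num_data)

grads = jax.grad(func)((jnp.asarray(2.0), jnp.asarray(100.0)))
# grads[0] is non-zero; grads[1] is 0.0 because of stop_gradient
```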
The point is that the differentiable parts are not static. A user may decide to change the differentiation path w.r.t. a structure. Therefore there is no way to know in advance how to split the object into differentiable and non-differentiable parts, especially when the differentiable structure can be extended, the number of parameters is unknown, and the underlying objective algorithm is the same regardless of the parameters.
Here is an example (apologies, it is almost a duplicate of what I showed before). Often we need to compute gradients w.r.t. the lengthscale or the variance only, while keeping the same code that uses the kernel. I propose to introduce a proxy type that stops gradient computation for plain types and still interacts cleanly with other structures. Therefore we could do something like this:
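(A minimal sketch of the intended usage; the original snippet is not reproduced here, and the names below, e.g. `SquaredExponential`, are illustrative only.)

```python
import typing

import jax
import jax.numpy as jnp

class SquaredExponential(typing.NamedTuple):
    lengthscale: jnp.ndarray
    variance: jnp.ndarray

def kernel_value(k, r):
    return k.variance * jnp.exp(-0.5 * (r / k.lengthscale) ** 2)

k = SquaredExponential(jnp.asarray(1.0), jnp.asarray(2.0))

# Today, grad produces derivatives for every field of the NamedTuple:
grads = jax.grad(lambda k_: kernel_value(k_, 0.5))(k)

# The proposal: wrapping a field, e.g.
#     SquaredExponential(jnp.asarray(1.0), NonDifferentiable(2.0))
# would make the same objective code skip the wrapped field entirely.
```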
I’m sorry, I don’t understand your example here. Can you elaborate? Thanks!
I wouldn’t say never. As far as I understand nobody has tried to implement it or fix it if that’s a bug. Also, it doesn’t mean that that level of flexibility cannot and should not be provided to a user. TensorFlow and PyTorch have this naive and simple feature since the beginning. In turn, JAX can give a hybrid solution to trainable arrays (structures) that will benefit lots of potential users, additionally to really cool features like
vmap
andpmap
.Thanks a lot! I thought that
flatten
can split up an object on static and non-static parts, and then build it up back. Therefore a hash can be computed with non-static objects only. Caveat: I don’t possess deep knowledge on this topic (yet).In this particular case, it doesn’t matter. But, there are situations, when we could switch on/off parameters like covariance matrix of a Gaussian distribution (sparse Gaussian processes). Where the computation of gradients takes a route of Cholesky and other expensive operators.
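For reference, a minimal sketch of what flattening actually does with the `NamedTuple` from the original example: the leaves are separated from a treedef describing the structure, and with a plain `NamedTuple` the integer field is still a leaf rather than static data (only a custom registration, as in the sketch further up, moves it into the treedef side).

```python
import typing

import jax.numpy as jnp
from jax import tree_util

class DifferentiableView(typing.NamedTuple):
    parameter: jnp.ndarray
    num_data: int

s = DifferentiableView(jnp.asarray(2.0), 100)

# tree_flatten splits a pytree into its leaves and a treedef for the structure.
leaves, treedef = tree_util.tree_flatten(s)
print(leaves)    # [Array(2., dtype=float32, ...), 100] -- num_data is still a leaf here
print(treedef)   # a PyTreeDef describing the NamedTuple structure

# The treedef rebuilds the original object from the leaves.
rebuilt = tree_util.tree_unflatten(treedef, leaves)
```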
Both proposals are good. However, I have an example where this would not work (or would not be user-friendly): compositional kernels such as sums and products, and hence their nested combinations. Writing extractor and replace functions for such compositions would be a nightmare, but perhaps there is another, better design for kernel compositions; one possible alternative is sketched below.
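One such alternative, sketched here with the composition expressed as a plain nested pytree (names and structure are purely illustrative), is to pick the trainable leaves with a mask pytree and `tree_map` rather than bespoke extract/replace functions:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Hypothetical nested composition, written as a plain pytree for illustration.
kernel = {
    "sum": [
        {"rbf": {"lengthscale": jnp.asarray(1.0), "variance": jnp.asarray(1.0)}},
        {"linear": {"variance": jnp.asarray(0.5)}},
    ]
}

def loss(k):
    # Toy objective that touches every parameter in the composition.
    return sum(jnp.sum(leaf ** 2) for leaf in tree_util.tree_leaves(k))

# A mask pytree marks which leaves are trainable; the nesting mirrors the kernel.
trainable = {
    "sum": [
        {"rbf": {"lengthscale": True, "variance": False}},
        {"linear": {"variance": False}},
    ]
}

grads = jax.grad(loss)(kernel)
# Zero out gradients for non-trainable leaves without any custom extractors.
masked = tree_util.tree_map(
    lambda g, t: g if t else jnp.zeros_like(g), grads, trainable)
```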