
Disable gradients with special type

See original GitHub issue

I have difficulties with structures that have non-differentiable parts, which causes a lot of issues in my current design. For example, I want a differentiable structure that has some discrete parts.

I decided to outsmart JAX by making an object that behaves like a scalar (an array) and passing it to the initialiser of DifferentiableView. But this runs into a different issue, related to gradient tracking and the types that are allowed to be traced.

import typing

import jax
import jax.numpy as jnp

class NonDifferentiable(int):
    pass

class DifferentiableView(typing.NamedTuple):
    parameter: jnp.ndarray
    num_data: int

def func(s):
    return jnp.exp(s.parameter) ** 2 / s.num_data

s = DifferentiableView(2.0, NonDifferentiable(100))
grads = jax.grad(func)(s)  # fails: NonDifferentiable is not a valid JAX type

Gives an error:

TypeError: <class '__main__.NonDifferentiable'> is not a valid Jax type

I think it would be pretty cool to have such a recognisable class for labelling (wrapping) scalars and shaped arrays. For JAX it would be a signal to stop tracing gradients for that variable.
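
For reference, a minimal sketch of a workaround that exists in JAX today, without a new wrapper type: keep every field a regular JAX value and block the gradient explicitly with jax.lax.stop_gradient. The names mirror the snippet above; storing num_data as a float array is an assumption made here so that grad accepts the pytree.

import typing

import jax
import jax.numpy as jnp

class DifferentiableView(typing.NamedTuple):
    parameter: jnp.ndarray
    num_data: jnp.ndarray  # stored as an array, but treated as a constant below

def func(s):
    # stop_gradient makes num_data behave like a constant under differentiation,
    # so its cotangent is zero even though it is an ordinary leaf of the pytree.
    num_data = jax.lax.stop_gradient(s.num_data)
    return jnp.exp(s.parameter) ** 2 / num_data

s = DifferentiableView(jnp.asarray(2.0), jnp.asarray(100.0))
grads = jax.grad(func)(s)
print(grads.parameter)  # 2 * exp(2 * 2.0) / 100
print(grads.num_data)   # 0.0: gradient flow is blocked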

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 3
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

1 reaction
awav commented, Jun 6, 2020

“My feeling is that this isn’t JAX’s problem, but rather a design problem on the user side.”

I disagree. The very first question when I started looking at JAX and recommending it to others was: “How do I stop gradient propagation without changing the code?”. This is a UX question, not a design issue on the user’s side; see https://twitter.com/srush_nlp/status/1260583364102434817?s=20.

The point is that the differentiable parts are not static. A user may decide to change the differentiation path w.r.t. a structure. Therefore there is no way to know in advance how to split the object into differentiable and non-differentiable parts, especially when the differentiable structure can be extended, the number of parameters is unknown, and the underlying objective algorithm stays the same “regardless” of the parameters.

Here is an example (apologies, it is almost a duplicate of what I showed before):

from dataclasses import dataclass

import jax.numpy as jnp

@dataclass(frozen=True)
class Kernel:
    variance: float

@dataclass(frozen=True)
class SquaredExponential(Kernel):
    lengthscale: float

@dataclass(frozen=True)
class UnknownKernelWithAdditional100HyperParameters(Kernel):
    ...

def loss(kernel: Kernel) -> jnp.ndarray:
    pass  # Compute the marginal likelihood of the Gaussian process here

Often we need to compute gradients w.r.t. the lengthscale or the variance only, while keeping the same code that uses the kernel. I propose introducing a proxy type that stops gradient computation for plain types while still interacting cleanly with other structures:

# hypothetical proxy class
class NonDifferentiable:
    def __init__(self, wrap_this_object):
        ...

a = jnp.zeros(10)
non_differentiable_a = NonDifferentiable(a)
a2 = non_differentiable_a + a    # Still works!

Therefore we could do this:

k1 = Kernel(1.0)
k2 = SquaredExponential(NonDifferentiable(1.0), 1.0)

jax.grad(loss)(k1)
jax.grad(loss)(k2)
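
For context, a minimal sketch, assuming a stand-in loss body, of one way this already composes today: register the kernel dataclass as a pytree, so that jax.grad(loss)(kernel) returns a gradient for every field and the caller selects the ones it cares about (this is the grad(loss)(kernel).variance pattern mentioned further down in the thread).

import dataclasses

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Kernel:
    variance: float

    def tree_flatten(self):
        return (self.variance,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

def loss(kernel: Kernel) -> jnp.ndarray:
    # Stand-in for the Gaussian process marginal likelihood.
    return jnp.log(kernel.variance) ** 2

k = Kernel(variance=2.0)
grads = jax.grad(loss)(k)   # a Kernel whose fields hold the gradients
print(grads.variance)       # select only the piece you care about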

“After all, you might choose to differentiate with respect to one thing in one context and another thing in another context.”

I’m sorry, I don’t understand your example here. Can you elaborate? Thanks!

“then you can never have a tracer in their place”

I wouldn’t say never. As far as I understand, nobody has tried to implement it (or fix it, if that’s a bug). And it doesn’t mean that this level of flexibility cannot or should not be provided to a user. TensorFlow and PyTorch have had this simple feature from the beginning. In turn, JAX could offer a hybrid solution for trainable arrays (structures) that would benefit lots of potential users, in addition to really cool features like vmap and pmap.

0 reactions
awav commented, Jun 8, 2020

“I was talking about the problem with the solutions that were suggested to you in this thread whereby you use the hashable output of tree_flatten to mark nondifferentiable arguments. Maybe I don’t understand your problem, but if it’s that you have non-concrete, nondifferentiable arguments, then these solutions will not work in general since those arguments are treated like static arguments, and it doesn’t make sense for tracers to be treated as static arguments. They can’t be hashed, which is what the jitter needs to do in order to prevent recompilation. I think it would be nice if JAX were a bit more proactive about preventing you from accidentally sending tracers down that path, but otherwise, you’re in for a lot of debugging (what happened to me) if you let that happen. I was just trying to warn you to save you the pain I went through 😃”

Thanks a lot! I thought that flatten could split an object into static and non-static parts and then build it back up, so that the hash would not need to involve the non-static parts. Caveat: I don’t have deep knowledge of this topic (yet).
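
As a small illustration of that split, here is a sketch with a toy, hypothetical View container: the children returned by tree_flatten are the dynamic leaves that get traced and differentiated, while the aux_data is the static part that jit hashes (and which therefore must stay concrete) to decide whether a compilation can be reused.

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class View:
    def __init__(self, parameter, num_data):
        self.parameter = parameter   # dynamic leaf: traced, differentiated
        self.num_data = num_data     # static metadata: must stay concrete and hashable

    def tree_flatten(self):
        return (self.parameter,), self.num_data

    @classmethod
    def tree_unflatten(cls, num_data, children):
        return cls(children[0], num_data)

def func(view):
    return jnp.exp(view.parameter) ** 2 / view.num_data

v = View(jnp.asarray(2.0), 100)
print(jax.grad(func)(v).parameter)  # gradient exists only for the dynamic leaf
print(jax.jit(func)(v))             # num_data is baked into the compiled code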

“I’m just curious, but what’s the problem with grad(loss)(kernel).variance? I think it’s only expensive to compile; computationally, it should be really fast, but maybe one of the JAX devs could confirm.”

In this particular case, it doesn’t matter. But there are situations where we want to switch parameters on/off, like the covariance matrix of a Gaussian distribution (sparse Gaussian processes), where the gradient computation goes through Cholesky factorisations and other expensive operators.

“If you don’t like the simple solution, maybe you could propose a fancier interface for grad like: … Or else, what do you think of this more general solution? …”

Both proposals are good, although I might have an example where they would not work (or would not be user friendly): compositional kernels like sums, products, and their nested combinations.

compositional_kernel = kernel1 + kernel2 * kernel3   # all kernels have variance parameter

Writing extractor and replace functions for such compositions would be a nightmare. But there might be another, better design for kernel compositions.
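
For concreteness, a toy sketch of the nesting in question, with hypothetical RBF and Sum classes (not any library’s API), assuming composite kernels are registered as pytrees: grad returns a matching nested structure, and any extract/replace machinery for selective differentiation would have to mirror exactly that nesting.

import dataclasses

import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class RBF:
    variance: float
    lengthscale: float

    def __call__(self, r):
        return self.variance * jnp.exp(-0.5 * (r / self.lengthscale) ** 2)

    def tree_flatten(self):
        return (self.variance, self.lengthscale), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Sum:
    left: object
    right: object

    def __call__(self, r):
        return self.left(r) + self.right(r)

    def tree_flatten(self):
        return (self.left, self.right), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

def loss(kernel):
    return kernel(jnp.asarray(1.5))

composite = Sum(RBF(1.0, 2.0), RBF(0.5, 1.0))
grads = jax.grad(loss)(composite)   # a Sum of RBFs holding per-leaf gradients
print(grads.left.lengthscale)       # selecting one leaf means walking the nesting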

