
Feature request: ability to apply stop gradient to some parameters

See original GitHub issue

To motivate this feature request, I’ll explain what I’m currently doing (without Flax) and the other solutions I’ve considered. Then I’ll suggest a possible Flax solution.

Problem:

In the process of inferring one of my modules, I need to mask varying subsets of the weights with stop-gradient in a single function:


def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()
    # Create four copies of the weights:
    # * weights_sg has stop_gradient applied to all weights, and
    # * the other three have stop_gradient applied to different partitions of
    # the weights.
    weights_sg, weights_g, weights_c, weights_e = _stop_gradient_on_some_weights(weights)

    # Inference ------------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    code_message = encoding.code_message(observation, weights_sg)

    # GLN loss -------------------------------------------------------------------------------------
    # The scan parameters depend on weights_g.
    encoding_parameters_g = SamplerParameters(observation, prediction, weights_g)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan through
    # to the observation.
    initial_code_message = stop_gradient(code_message)
    # This class manages an iterated function (a scan)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # This is an extremely computationally expensive scan.
    sampler_state, sampler_trajectory = sampler.sample_trajectory(
        encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss ------------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts the
    # code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    c_code_message = encoding.code_message(observation, weights_c, rng=code_rng,
                                           use_code_signal_noise=True)
    # When this loss is minimized only the weights that are not marked stop-gradient in weights_c
    # are adjusted.  Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.

    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)

# Below is the code that uses Haiku to partition the weights and apply stop gradient to different
# partitions.
_module_classes = [{'gln'}, {'code_value', 'code_presence'}, {'explanation'}]

def _module_predicate(module_name: str,
                      name: str,
                      value: Array) -> int:
    prefix = module_name.split('/')[0]
    for i, prefix_set in enumerate(_module_classes):
        if prefix in prefix_set:
            return i
    raise RuntimeError

# I was using Haiku before, but I'll have to port this to Flax somehow (a rough
# traverse_util-based port is sketched after this snippet).
def _partition_by_module(weights: FrozenVariableDict) -> tuple[FrozenVariableDict, ...]:
    return hk.data_structures.partition_n(_module_predicate,  # type: ignore[arg-type]
                                          weights, len(_module_classes))

def _stop_gradient_on_some_weights(weights: FrozenVariableDict) -> list[FrozenVariableDict]:
    weights_sg = stop_gradient(weights)
    weights_p = _partition_by_module(weights)
    weights_sg_p = _partition_by_module(weights_sg)

    return ([weights_sg]
            + [hk.data_structures.merge(weights_pi,
                                        *[weights_sg_pi
                                          for j, weights_sg_pi in enumerate(weights_sg_p)
                                          if i != j])
               for i, weights_pi in enumerate(weights_p)])
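
For reference, here is a rough sketch of how I imagine porting the partitioning to flax.traverse_util, reusing _module_classes from above (a sketch only: it assumes the module name is the first element of each flattened path, which may need adjusting for the real FrozenVariableDict layout):

from jax.lax import stop_gradient
from flax import traverse_util
from flax.core import freeze, unfreeze

def _class_index(path) -> int:
    # Assumes the module name is the first element of the flattened path;
    # adjust the index if a collection name (e.g. 'params') sits above it.
    prefix = path[0]
    for i, prefix_set in enumerate(_module_classes):
        if prefix in prefix_set:
            return i
    raise RuntimeError(f'unknown module prefix: {prefix}')

def _stop_gradient_on_some_weights_flax(weights):
    flat = traverse_util.flatten_dict(unfreeze(weights))
    # First copy: stop_gradient applied to every weight.
    copies = [{k: stop_gradient(v) for k, v in flat.items()}]
    # One further copy per class: only that class's weights keep gradients.
    for i in range(len(_module_classes)):
        copies.append({k: v if _class_index(k) == i else stop_gradient(v)
                       for k, v in flat.items()})
    return [freeze(traverse_util.unflatten_dict(c)) for c in copies]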

Non-solution:

I discussed this with @cgarciae and we brainstormed a non-solution: I could try to put the “C”, “G”, and “E” weights into different “collections” and then run inference three times (sketched after the list below). This doesn’t work because:

  • the scan is very expensive and I don’t want to run it three times,
  • the scan can’t easily be hoisted out because the scan itself depends on the G collection of weights, and
  • all sorts of intermediate values are shared between the different parts of the function.
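
For concreteness, here is a toy sketch of the shape that approach would take (toy_infer and the parameter names are made-up stand-ins for the real model, and the Flax collections machinery is elided):

import jax
import jax.numpy as jnp

# Toy stand-in for the real inference pass: it touches every weight group,
# so each jax.grad call below re-runs the whole thing; in the real model
# that means re-running the expensive sampler scan three times.
def toy_infer(trainable, frozen, x):
    w = {**trainable, **frozen}
    h = jnp.tanh(x @ w['gln'])
    return jnp.sum((h @ w['code']) * w['explanation'])

params = {'gln': jnp.ones((4, 4)),
          'code': jnp.ones((4, 4)),
          'explanation': jnp.ones(4)}

grads = {}
for group in ('gln', 'code', 'explanation'):
    trainable = {group: params[group]}
    frozen = {k: v for k, v in params.items() if k != group}
    grads[group] = jax.grad(toy_infer)(trainable, frozen, jnp.ones((2, 4)))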

Possible Flax interface:

We came up with two Flax interfaces that might work.

I suggested some kind of context manager flax.linen.stop_gradient:


def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()

    # Inference ------------------------------------------------------------------------------------
    # This function uses weights_sg so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    with nn.stop_gradient(lambda c: True):
        code_message = encoding.code_message(observation)

    # GLN loss -------------------------------------------------------------------------------------
    encoding_parameters_g = SamplerParameters(observation, prediction)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan through
    # to the observation.
    initial_code_message = stop_gradient(code_message)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # The scan parameters depend on weights_g.
    with nn.stop_gradient(lambda c: c.name.startswith('gln')):
        # This class manages an iterated function (a scan)
        # This is an extremely computationally expensive scan.
        sampler_state, sampler_trajectory = sampler.sample_trajectory(
            encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message

    # Code loss ------------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts the
    # code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    with nn.stop_gradient(lambda c: c.name.startswith('code')):
        c_code_message = encoding.code_message(observation, rng=code_rng,
                                               use_code_signal_noise=True)
    # When this loss is minimized only the weights that are not marked stop-gradient in weights_c
    # are adjusted.  Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss

    # Snipped a lot of code here that uses weights_e and produces output primals.

    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)

Cristian suggested a lifting transformation like those found in flax.core.lift. I’m still learning how these work, so I can’t yet sketch what this might look like.

Possible side benefits

Besides applying stop-gradient, this kind of system may be able to do other things with parameters such as:

  • marking parameters as constant within a block, and raising if any computation tries to change them,
  • temporarily replacing parameters with values from another variable within a computation, or
  • replacing parameter cotangents with values from another variable or another parameter cotangent.

Of course, that’s beyond this feature request, but I mention these ideas as something to keep in mind when considering solutions.

Conclusion

Am I missing an easy solution to my problem? If not, I will need to solve this problem in order to use Flax since this use of stop-gradient is integral to my research. Thanks for reading!

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 7 (6 by maintainers)

Top GitHub Comments

2 reactions
jheek commented, Feb 15, 2022

Here’s a sketch of what that would look like:

import flax.linen as nn
from flax import traverse_util
from jax import lax

# some_filter_fn and MySubModule are placeholders for your own path predicate
# and submodule.
def selective_stop_grad(variables):
    flat_vars = traverse_util.flatten_dict(variables)
    new_vars = {k: lax.stop_gradient(v) if some_filter_fn(k) else v
                for k, v in flat_vars.items()}
    return traverse_util.unflatten_dict(new_vars)


class MySGModule(nn.Module):
    @nn.compact
    def __call__(self, x):
        MySGSubModule = nn.map_variables(MySubModule, "params", selective_stop_grad, init=True)
        return MySGSubModule(...)(x)
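
For illustration, some_filter_fn could be a simple predicate over the flattened path (hypothetical; flatten_dict produces tuples of nested dict keys):

def some_filter_fn(path):
    # Stop gradients for any parameter that lives under a module whose name
    # starts with 'gln'; `path` is a tuple like ('gln_module', 'Dense_0', 'kernel').
    return any(part.startswith('gln') for part in path)
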
0 reactions
NeilGirdhar commented, Feb 17, 2022

@jheek Thanks, that gets it to run, but it’s still not reflecting a copy of x’s parameters. It outputs:

[-1.5530705 -0.6934959  0.9631546] [ 0.246286    0.83799624 -0.91129684]
FrozenDict
    params=FrozenDict
        stop_gradient_all=FrozenDict
            submodule=FrozenDict
                dense=FrozenDict
                    bias=Jax Array (3,) float32
                            0.0000      0.0000      0.0000
                    kernel=Jax Array (3, 3) float32
                            0.3932      0.3981     -0.5165
                            0.1566     -0.0768     -0.2396
                           -0.3035      0.5167     -0.1552
        x=FrozenDict
            dense=FrozenDict
                bias=Jax Array (3,) float32
                        0.0000      0.0000      0.0000
                kernel=Jax Array (3, 3) float32
                       -0.0201     -0.6220      0.9425
                       -0.2652     -0.1386      0.8165
                       -1.2677      0.0672     -0.7958

So x and stop_gradient_all are different. Any idea how I can make it a mirror? I realize I need to pass x somehow, but I’m still not sure how.

Read more comments on GitHub >
