Feature request: ability to apply stop gradient to some parameters
To motivate this feature request, I’ll explain what I’m currently doing (without Flax) and the other solutions I’ve considered. Then I’ll suggest some possible Flax interfaces.
Problem:
In the process of inferring one of my modules, I need to mask varying subsets of the weights with stop-gradient in a single function:
def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()
    # Create four copies of the weights:
    # * weights_sg has stop_gradient applied to all weights, and
    # * the other three have stop_gradient applied to different partitions of
    #   the weights.
    weights_sg, weights_g, weights_c, weights_e = _stop_gradient_on_some_weights(weights)
    # Inference ----------------------------------------------------------------------------------
    # This function uses weights_sg, so this calculation won't poison the weight cotangents.
    # However, cotangents still propagate back to observation.
    code_message = encoding.code_message(observation, weights_sg)
    # GLN loss -----------------------------------------------------------------------------------
    # The scan parameters depend on weights_g.
    encoding_parameters_g = SamplerParameters(observation, prediction, weights_g)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan
    # through to the observation.
    initial_code_message = stop_gradient(code_message)
    # This class manages an iterated function (a scan).
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # This is an extremely computationally expensive scan.
    sampler_state, sampler_trajectory = sampler.sample_trajectory(
        encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the subset of weights in weights_g.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message
    # Code loss ----------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts
    # the code message that we inferred by iteration.
    # This is the same code_message function as above, but uses weights_c.
    c_code_message = encoding.code_message(observation, weights_c, rng=code_rng,
                                           use_code_signal_noise=True)
    # When this loss is minimized, only the weights that are not marked stop-gradient in weights_c
    # are adjusted.  Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss
    # Snipped a lot of code here that uses weights_e and produces output primals.
    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)
# Below is the code that uses Haiku to partition the weights and apply stop gradient to different
# partitions.
_module_classes = [{'gln'}, {'code_value', 'code_presence'}, {'explanation'}]


def _module_predicate(module_name: str,
                      name: str,
                      value: Array) -> int:
    prefix = module_name.split('/')[0]
    for i, prefix_set in enumerate(_module_classes):
        if prefix in prefix_set:
            return i
    raise RuntimeError


# I was using Haiku before, but I'll have to port this to Flax somehow.
def _partition_by_module(weights: FrozenVariableDict) -> tuple[FrozenVariableDict, ...]:
    return hk.data_structures.partition_n(_module_predicate,  # type: ignore[arg-type]
                                          weights, len(_module_classes))


def _stop_gradient_on_some_weights(weights: FrozenVariableDict) -> list[FrozenVariableDict]:
    weights_sg = stop_gradient(weights)
    weights_p = _partition_by_module(weights)
    weights_sg_p = _partition_by_module(weights_sg)
    return ([weights_sg]
            + [hk.data_structures.merge(weights_pi,
                                        *[weights_sg_pi
                                          for j, weights_sg_pi in enumerate(weights_sg_p)
                                          if i != j])
               for i, weights_pi in enumerate(weights_p)])
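For reference, here is a rough sketch of how the same partition/merge logic might be written against a Flax variable dict with flax.traverse_util and jax.lax.stop_gradient. It assumes (hypothetically) that weights is the 'params' tree and that its top-level keys are the module names listed in _module_classes above:

import jax
from flax import traverse_util
from flax.core import freeze, unfreeze

# Reuses _module_classes from the Haiku code above.
def _class_index(top_level_name: str) -> int:
    # Same classification as _module_predicate, but on top-level Flax keys.
    for i, prefix_set in enumerate(_module_classes):
        if top_level_name in prefix_set:
            return i
    raise RuntimeError(top_level_name)


def _stop_gradient_on_some_weights(weights):
    # weights_sg has stop_gradient applied to every leaf.
    weights_sg = jax.lax.stop_gradient(weights)
    flat = traverse_util.flatten_dict(unfreeze(weights))
    flat_sg = traverse_util.flatten_dict(unfreeze(weights_sg))
    result = [weights_sg]
    for i in range(len(_module_classes)):
        # Keep live gradients only for class i; use stopped leaves everywhere else.
        merged = {path: (leaf if _class_index(path[0]) == i else flat_sg[path])
                  for path, leaf in flat.items()}
        result.append(freeze(traverse_util.unflatten_dict(merged)))
    return result

The list it returns has the same order and meaning as the Haiku version above.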
Non-solution:
I discussed this with @cgarciae and brainstormed a non-solution: I could try to put the “C”, “G”, and “E” weights into different “collections”, and then run inference three times. This doesn’t work because:
- the scan is very expensive and I don’t want to run it three times,
- the scan can’t easily be hoisted out because the scan itself depends on the G collection of weights, and
- all sorts of intermediate values are created and shared between the different parts of the function.
Possible Flax interface:
We came up with two Flax interfaces that might work.
I suggested some kind of context manager, flax.linen.stop_gradient:
def infer(encoding: EncodingElement,
          observation: PoolingMessage,
          prediction: PredictionMessage,
          rng: Generator,
          weights: FrozenVariableDict) -> TwoPassEncodingConfiguration:
    sampler_rng, code_rng = rng.split()
    # Inference ----------------------------------------------------------------------------------
    # This block plays the role of weights_sg above, so this calculation won't poison the weight
    # cotangents.  However, cotangents still propagate back to observation.
    with nn.stop_gradient(lambda c: True):
        code_message = encoding.code_message(observation)
    # GLN loss -----------------------------------------------------------------------------------
    encoding_parameters_g = SamplerParameters(observation, prediction)
    # This use of stop_gradient prevents the cotangents from propagating back from the scan
    # through to the observation.
    initial_code_message = stop_gradient(code_message)
    sampler = EncodingSampler(encoding)
    sampler_iterations = encoding.inference_parameters.sampler_iterations
    initial_sampler_state = SamplerState.initial_state(encoding, initial_code_message, sampler_rng)
    # This block plays the role of weights_g above; the scan parameters depend on those weights.
    with nn.stop_gradient(lambda c: c.name.startswith('gln')):
        # This class manages an iterated function (a scan).
        # This is an extremely computationally expensive scan.
        sampler_state, sampler_trajectory = sampler.sample_trajectory(
            encoding_parameters_g, initial_sampler_state, sampler_iterations, None)
    # We calculate a GLN loss, which can only affect the GLN subset of the weights.
    gln_loss = ((sampler_state.total_gln_centering_loss + sampler_state.total_prediction_loss)
                / sampler_iterations)
    iterative_code_message = sampler_state.code_message
    # Code loss ----------------------------------------------------------------------------------
    # The code loss trains the code and selection links to produce a code message that predicts
    # the code message that we inferred by iteration.
    # This is the same code_message call as above, but this block plays the role of weights_c.
    with nn.stop_gradient(lambda c: c.name.startswith('code')):
        c_code_message = encoding.code_message(observation, rng=code_rng,
                                               use_code_signal_noise=True)
    # When this loss is minimized, only the weights that are not marked stop-gradient here are
    # adjusted.  Cotangents are also blocked from poisoning the scan by applying stop_gradient
    # to its outputs.
    code_presence_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.log_presence)
                                            - c_code_message.log_presence))
    code_value_loss = jnp.sum(jnp.square(stop_gradient(iterative_code_message.code_value)
                                         - c_code_message.code_value))
    code_loss = code_presence_loss + code_value_loss
    # Snipped a lot of code here that uses the E partition of the weights and produces output
    # primals.
    return TwoPassEncodingConfiguration(iterative_code_message, gln_loss, code_loss)
Cristian suggested a lifting transformation like those found in flax.core.lift. I’m still learning how these work, so I can’t yet sketch what this might look like.
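As a point of reference (not the proposed API), a restricted version of this behaviour appears to be expressible today with nn.map_variables, which wraps the lifting machinery in flax.core.lift: wrap a submodule so that its 'params' collection is passed through jax.lax.stop_gradient on the way in. A minimal sketch with illustrative names:

import jax
import jax.numpy as jnp
import flax.linen as nn

class PartiallyFrozen(nn.Module):
    """Gradients flow to 'trainable' but not to 'frozen' (illustrative names)."""
    features: int = 8

    def _stop_grad(self, variables):
        # Leave the variables untouched while they are being created.
        if self.is_initializing():
            return variables
        return jax.lax.stop_gradient(variables)

    @nn.compact
    def __call__(self, x):
        # Wrap nn.Dense so that its 'params' are transformed on the way in.
        FrozenDense = nn.map_variables(nn.Dense, 'params', self._stop_grad,
                                       init=self.is_initializing())
        x = nn.Dense(self.features, name='trainable')(x)
        x = FrozenDense(self.features, name='frozen')(x)
        return x.sum()

model = PartiallyFrozen()
x = jnp.ones((2, 4))
variables = model.init(jax.random.PRNGKey(0), x)
grads = jax.grad(model.apply)(variables, x)
# grads['params']['frozen'] is all zeros; grads['params']['trainable'] is not.

This only covers static, per-submodule freezing, though; it does not obviously extend to the predicate-scoped blocks sketched above.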
Possible side benefits
Besides applying stop-gradient, this kind of system may be able to do other things with parameters such as:
- marking parameters as constant within a block, and raising if any computation tries to change them,
- temporarily replacing parameters with values from another variable within a computation (see the sketch after this list), or
- replacing parameter cotangents with values from another variable or another parameter cotangent.
Of course, that’s beyond this feature request, but I mention these ideas as something to keep in mind when considering solutions.
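The second idea in that list, temporarily substituting parameter values, can already be approximated without any new API by patching the variable dict before apply. A rough sketch, where module, variables, donor_variables, and 'donor_target' are all hypothetical names:

from flax.core import freeze, unfreeze

def apply_with_substituted_params(module, variables, donor_variables, inputs,
                                  submodule_name='donor_target'):
    # Copy the variable tree and splice in one submodule's parameters from the donor.
    patched = unfreeze(variables)
    patched['params'][submodule_name] = donor_variables['params'][submodule_name]
    return module.apply(freeze(patched), inputs)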
Conclusion
Am I missing an easy solution to my problem? If not, I will need to solve this problem in order to use Flax since this use of stop-gradient is integral to my research. Thanks for reading!
Top GitHub Comments
Here’s a sketch of what that would look like:
@jheek Thanks, that gets it to run, but it’s still not reflecting a copy of x’s parameters? It outputs:

So x and stop_gradient_all are different. Any idea how I can make it a mirror? I realize I need to pass x somehow, but I’m still not sure how.
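One generic way to get a stop-gradient mirror that genuinely shares parameters (not necessarily the resolution reached in this thread) is to apply the same variable tree twice, passing it through jax.lax.stop_gradient for the frozen call. A minimal sketch with illustrative names:

import jax
import jax.numpy as jnp
import flax.linen as nn

x_module = nn.Dense(3)
inputs = jnp.ones((1, 4))
variables = x_module.init(jax.random.PRNGKey(0), inputs)

def forward(variables, inputs):
    live = x_module.apply(variables, inputs)                           # trains x_module
    mirror = x_module.apply(jax.lax.stop_gradient(variables), inputs)  # frozen mirror
    return (live + mirror).sum()

grads = jax.grad(forward)(variables, inputs)  # only the live path contributes cotangents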