
How to freeze parameters?


I’m looking for a way to freeze parameters in the most efficient manner. One option would be to take the current TrainingState (I’m coming from flax) and manually partition state.params like this:

# Split the original TrainingState's parameter pytree by hand.
# `original_unfrozen_model`, `manually_partition_parameters` and
# `manually_merge_params` stand in for user-written code.
raw_state: TrainingState = original_unfrozen_model
frozen_params, learnable_params = manually_partition_parameters(raw_state.params)

def frozen_apply_fn(params, *args, **kwargs):
  # Re-insert the frozen leaves before calling the original apply function,
  # so the model still sees the full parameter pytree.
  return raw_state.apply_fn(
    manually_merge_params(params, frozen_params),
    *args,
    **kwargs,
  )

partially_frozen_state = TrainingState(
  params=learnable_params,
  apply_fn=frozen_apply_fn,
)

and then create a new optimizer for partially_frozen_state. The downside of this approach is that it completely mutilates the structure of the pytree, making serialized states incompatible between the frozen and unfrozen variants, and it requires manual splitting and merging of the parameters in a consistent manner (an easy source of bugs).

An alternative would be to use optax.masked to simply not apply updates to certain parameters. My question, however, is whether JAX (XLA) is smart enough to optimize the computation of gradients w.r.t. masked-out parameters out of the compute graph. This isn’t clear to me, because the general code flow is to create the neural network in flax, which returns the application function and a pytree of parameters; I then manually (without optax’s involvement) compute the gradients w.r.t. all parameters, and only then are they passed to optax, which may use them or not. If the gradients for frozen parameters are computed but simply discarded, this solution would fail to save GPU memory, which is a major reason for freezing the parameters in the first place. If it does work, I’m perfectly satisfied with this approach.
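
For concreteness, here is a minimal sketch of the masked-update approach. The toy parameter tree, the encoder/head names, Adam and the 1e-3 learning rate are all illustrative assumptions, not part of the original question; note that optax.masked passes un-masked updates through unchanged, which is why the frozen subset is explicitly zeroed here.

import jax
import jax.numpy as jnp
import optax

# Toy parameter pytree standing in for flax's model.init(...) output.
params = {'encoder': {'w': jnp.ones((3, 3))}, 'head': {'w': jnp.ones((3, 1))}}

# Boolean mask with the same tree structure as params: True marks frozen leaves.
frozen_mask = {'encoder': {'w': True}, 'head': {'w': False}}

# optax.masked applies its inner transformation only where the mask is True and
# passes the remaining updates through unchanged, so chaining a masked
# set_to_zero after the optimizer zeroes the updates of the frozen subset.
optimizer = optax.chain(
    optax.adam(1e-3),
    optax.masked(optax.set_to_zero(), frozen_mask),
)
opt_state = optimizer.init(params)

def loss_fn(p):
    return jnp.sum(p['encoder']['w'] @ p['head']['w'])

# Gradients are still computed w.r.t. every leaf here; only the update step
# ignores the frozen ones.
grads = jax.grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

At the Python level the mask is applied only after jax.grad has returned a full gradient pytree, which is exactly why the memory question above matters.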

I am submitting this as an issue rather than a discussion because searching the docs for “freeze”/“frozen” currently yields no hits at all, so the topic is either not covered or insufficiently discoverable.

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

1 reaction
jatentaki commented, Feb 3, 2022

@mkunesch sorry for the late reply, it slipped through. Thanks a lot for your explanation. I agree with @rosshemsley that an example would be best, but it should also be searchable in the API reference itself. I think modifying the docstring to say something like “this operation can be used to freeze parameters as: <insert one line example>” would be good; adding explicit function name aliases shouldn’t be necessary 😃

0 reactions
mkunesch commented, Feb 8, 2022

Hi! We improved the documentation of this in #299 and opened #296 to add a freezing example so I’ll close this issue. Thanks a lot for raising it!
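
For reference, a freezing recipe in the spirit of the example mentioned above could look roughly like the sketch below. This is not the content of #296; the label layout mirrors the illustrative encoder/head tree from the earlier snippet.

import optax

# One label per leaf of the parameter pytree; the encoder/head layout is the
# same illustrative toy structure as in the earlier sketch.
param_labels = {'encoder': {'w': 'frozen'}, 'head': {'w': 'trainable'}}

# Route each labelled group to its own transformation: Adam for the trainable
# leaves, set_to_zero for the frozen ones.
optimizer = optax.multi_transform(
    {'trainable': optax.adam(1e-3), 'frozen': optax.set_to_zero()},
    param_labels,
)

The resulting optimizer is then used like any other optax transformation via init, update and apply_updates.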


Top Results From Across the Web

  • How the pytorch freeze network in some layers, only the rest of ...
    For partially unfreezing some of the last layers, we can identify parameters we want to unfreeze in this loop. setting the flag to...
  • How to freeze selected layers of a model in Pytorch?
    afterwards to unfreeze parameters in (15). Loop from 15 to 18 to unfreeze the last several layers.
  • Pytorch freeze part of the layers - Jimmy (xiaoke) Shen - Medium
    In PyTorch we can freeze the layer by setting the requires_grad to False. The weight freeze is helpful when we want to apply...
  • How to freeze model parameters? - Google Groups
    I’m wondering if there is a way to fix variable values of a model when training models like GAN? I know it’s possible...
  • Freezing parameters - Equinox - Patrick Kidger
    In this example, we demonstrate how the filtering of equinox.filter_value_and_grad can be customised -- in this case, to only train some parameters and...
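
Most of the PyTorch links above describe the same requires_grad pattern; a minimal sketch of it, using an arbitrary toy model, looks like this:

import torch
import torch.nn as nn

# Arbitrary toy model for illustration.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

# Freeze everything, then re-enable gradients for the last layer only.
for param in model.parameters():
    param.requires_grad = False
for param in model[2].parameters():
    param.requires_grad = True

# Hand the optimizer only the parameters that still require gradients.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)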
