How to freeze parameters?
I’m looking for a way to freeze parameters in the most efficient manner. One option would be to take the current TrainingState (I’m coming from flax) and manually partition state.params, like this:
    raw_state: TrainingState = original_unfrozen_model
    frozen_params, learnable_params = manually_partition_parameters(raw_state.params)

    def frozen_apply_fn(params, *args, **kwargs):
        return raw_state.apply_fn(
            manually_merge_params(params, frozen_params),
            *args,
            **kwargs
        )

    partially_frozen_state = TrainingState(
        params=learnable_params,
        apply_fn=frozen_apply_fn
    )
and then create a new optimizer for partially_frozen_state. The downside of that approach is that it completely mutilates the structure of the pytree, making serialized states incompatible between the frozen and unfrozen variants, and it requires manually splitting and merging the parameters in a consistent manner (an easy source of bugs).
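For concreteness, here is a minimal, runnable sketch of the manual split-and-merge approach on a plain dict pytree. The names `partition`, `merge`, `FROZEN_KEYS` and the toy two-module model are hypothetical stand-ins for `manually_partition_parameters` and `manually_merge_params`, and it assumes whole top-level modules are frozen. It is an illustration, not how flax or optax do it.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree; module names are illustrative only.
params = {
    "encoder": {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)},
    "head": {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)},
}
FROZEN_KEYS = ("encoder",)  # assumption: freeze whole top-level modules


def partition(params, frozen_keys):
    """Split a top-level dict into (frozen, learnable) sub-dicts."""
    frozen = {k: v for k, v in params.items() if k in frozen_keys}
    learnable = {k: v for k, v in params.items() if k not in frozen_keys}
    return frozen, learnable


def merge(learnable, frozen):
    """Rebuild the full tree; the two key sets are disjoint by construction."""
    return {**frozen, **learnable}


def apply_fn(params, x):
    h = jnp.tanh(x @ params["encoder"]["w"] + params["encoder"]["b"])
    return h @ params["head"]["w"] + params["head"]["b"]


def loss_fn(learnable, frozen, x, y):
    pred = apply_fn(merge(learnable, frozen), x)
    return jnp.mean((pred - y) ** 2)


frozen, learnable = partition(params, FROZEN_KEYS)
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

# jax.grad differentiates w.r.t. the first argument only, so the result
# has the structure of `learnable`, not of the full parameter tree.
grads = jax.jit(jax.grad(loss_fn))(learnable, frozen, x, y)
```

Since only `learnable` is differentiated, no gradient for the frozen subtree is ever requested. The cost is the one described above: `grads` and any optimizer state built from `learnable` no longer match the structure of the full tree.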
An alternative would be to use optax.masked to just not apply updates to certain parameters. My question, however, is whether JAX (XLA) is smart enough to optimize the computation of gradients w.r.t. masked-out parameters out of the compute graph. This isn’t clear to me, because the general code flow is to create the neural network in flax, which returns the application function and a pytree of parameters. I then manually (without optax’s involvement) compute the gradients w.r.t. all parameters, and only then are those passed to optax, which may or may not use them. If the gradients for frozen parameters are computed but just discarded, this solution would fail to save GPU memory, a major reason for freezing the parameters in the first place. If the gradients are optimized away, I’m perfectly satisfied with that approach.
I am submitting this as an issue, rather than a discussion, because currently searching the docs for “freeze”/“frozen” yields no hits at all, so the topic is either not covered at all or insufficiently discoverable.
Issue Analytics
- Created 2 years ago
- Comments:7 (4 by maintainers)
@mkunesch sorry for the late reply, it slipped through. Thanks a lot for your explanation. I agree with @rosshemsley that an example would be best, but it should also be searchable in the API reference itself. I think modifying the docstring to say something like “this operation can be used to freeze parameters as: <insert one line example>” would be good; adding explicit function name aliases shouldn’t be necessary 😃
Hi! We improved the documentation of this in #299 and opened #296 to add a freezing example so I’ll close this issue. Thanks a lot for raising it!