Should Flax return FrozenDicts or regular dicts?
This topic is discussed regularly internally, and I feel we haven’t reached a consensus here. Below are some arguments collected from users for both positions; feel free to add more.
Arguments in favor of FrozenDict
- @avital: If you use normal dicts, it is easy to mutate them, which means behavior may differ depending on whether the function that makes the modification is jitted. Example:

      def f(params):
          params['conv1']['weight'] = ...
          return ...  # some computation over params

      params = load_from_checkpoint()
      print(f(params))
      # Now what is the value of params['conv1']['weight']?
      # Depending on whether f is jitted or not, you'd get different results.
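The divergence described above can be sketched in plain Python, without JAX. This is only a simulation under a stated assumption: `jax.jit` flattens its arguments and calls the function on rebuilt containers, so an in-place update inside the function never reaches the caller's dict. The names `rebuild` and `call_on_rebuilt_containers` are hypothetical stand-ins for that behavior and are not part of JAX or Flax.

```python
def rebuild(tree):
    """Rebuild nested dicts so the result shares no containers with the input."""
    if isinstance(tree, dict):
        return {k: rebuild(v) for k, v in tree.items()}
    return tree


def call_on_rebuilt_containers(fn):
    """Hypothetical stand-in for jit: run fn on rebuilt containers."""
    def wrapped(params):
        return fn(rebuild(params))
    return wrapped


def f(params):
    # In-place update, as in the example above.
    params["conv1"]["weight"] = 0.0
    return params["conv1"]["weight"] + 1.0


# Plain call: the caller's dict is mutated.
params = {"conv1": {"weight": 2.0}}
eager_out = f(params)
eager_seen = params["conv1"]["weight"]

# Simulated "jitted" call: the mutation lands on a rebuilt dict instead.
params = {"conv1": {"weight": 2.0}}
wrapped_out = call_on_rebuilt_containers(f)(params)
wrapped_seen = params["conv1"]["weight"]

print(eager_out, eager_seen)      # 1.0 0.0
print(wrapped_out, wrapped_seen)  # 1.0 2.0
```

Both calls return the same value, but the caller is left holding different parameters afterwards, which is the inconsistency this argument is about.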
Arguments in favor of regular dicts
- @lucasb-eyer: Flax tells me “here’s these precious weights, please hold them for me and give them back to me later on, but DONT TOUCH”. It begs the question: why give them to me in the first place, if I’m not supposed to do anything with them?
- @avital: I also think it’d be better for Flax to return normal Python dicts, but still use FrozenDict within modules (via the mutable argument to apply).
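The middle ground in the last bullet (frozen inside modules, plain dicts at the boundary) can be sketched without Flax. The `freeze` and `unfreeze` helpers below are hypothetical stand-ins built on the standard library's `types.MappingProxyType`. They only mimic the idea behind `flax.core.freeze` and `flax.core.unfreeze` and are not Flax's implementation.

```python
from types import MappingProxyType


def freeze(tree):
    """Hypothetical stand-in: wrap nested dicts in read-only views."""
    if isinstance(tree, (dict, MappingProxyType)):
        return MappingProxyType({k: freeze(v) for k, v in tree.items()})
    return tree


def unfreeze(tree):
    """Hypothetical stand-in: rebuild nested read-only views as plain dicts."""
    if isinstance(tree, (dict, MappingProxyType)):
        return {k: unfreeze(v) for k, v in tree.items()}
    return tree


# "Inside the module": parameters are read-only.
internal = freeze({"conv1": {"weight": 1.0}})
try:
    internal["conv1"]["weight"] = 2.0
    write_blocked = False
except TypeError:
    write_blocked = True

# "At the boundary": the user receives an independent, mutable dict.
returned = unfreeze(internal)
returned["conv1"]["weight"] = 2.0
```

Under this arrangement the user can edit what they were handed, while the frozen copy used internally keeps its original value.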
Issue Analytics
- State:
- Created 2 years ago
- Comments: 24 (15 by maintainers)
Top GitHub Comments
Sorry for the delay – I was on parental leave.
@jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?
Hidden state is notoriously hard to reason about and I think all ML frameworks are struggling with it currently. See for example the widely made mistake of mixing the default numpy rng and multiprocessing (link). It’s a hard issue to fix though because mutability “infects” all of your code and Python isn’t a functional language.
That said, I don’t think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, and it doesn’t prevent the more common issues of closing over mutable state (typically created by the user) or using things like np.random in a jitted function.
I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:
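The snippet from the comment is not preserved in this copy. As a rough sketch only, such a helper might look like the plain-Python function below. The name `clone_pytree` is hypothetical, and the function handles only dicts, lists and tuples. In JAX itself, `jax.tree_util.tree_map(lambda x: x, tree)` should have a similar effect for standard containers, though that is an assumption about what was intended.

```python
def clone_pytree(tree):
    """Rebuild dict/list/tuple containers; leaves are shared, not copied."""
    if isinstance(tree, dict):
        return {k: clone_pytree(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [clone_pytree(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(clone_pytree(v) for v in tree)
    return tree  # leaf: returned as-is


params = {"conv1": {"weight": [1.0, 2.0]}, "step": 3}
cloned = clone_pytree(params)

# Mutating the clone's containers leaves the original untouched.
cloned["conv1"]["weight"].append(3.0)
cloned["conv1"]["bias"] = 0.0
```

Because only the containers are rebuilt, the clone is cheap: array leaves would be shared rather than copied.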
Also, we want to merge the chex and flax dataclass implementations. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent, so ideally we would make these changes together.