Should Flax return FrozenDicts or regular dicts?

This topic is discussed regularly internally, and I feel we haven't reached a consensus here. Below are some arguments collected from users for both positions; feel free to add more.

Arguments in favor of FrozenDict

  • @avital: If you use normal dicts, it is easy to mutate them, which means the result can differ depending on whether the function that performs the mutation is jitted or not. Example:
def f(params):
  params['conv1']['weight'] = ...
  return ...some computation over params

params = load_from_checkpoint()
print(f(params))
# now what is the value of params['conv1']['weight']?
# depending on whether f is jitted or not, you'd get different results
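To make the question in the comments concrete, here is a minimal runnable sketch of the same pattern. It assumes a tiny hypothetical make_params() in place of load_from_checkpoint(), and it writes zeros where the original elides the new value. Called eagerly, f writes into the caller's dict; under jax.jit, f is traced with a dict rebuilt from the flattened inputs, so the caller's dict is left untouched.

```python
import jax
import jax.numpy as jnp

def make_params():
    # Hypothetical stand-in for load_from_checkpoint(): one tiny "conv1" layer.
    return {'conv1': {'weight': jnp.ones((3,))}}

def f(params):
    # Overwrites an entry of the nested dict it was given.
    params['conv1']['weight'] = jnp.zeros((3,))
    return jnp.sum(params['conv1']['weight'])

# Eager call: f receives the caller's dict, so the write leaks out.
eager_params = make_params()
eager_out = f(eager_params)
eager_weight_after = eager_params['conv1']['weight']  # now all zeros

# Jitted call: f runs on a dict rebuilt from traced leaves,
# so the write only touches that internal copy.
jit_params = make_params()
jit_out = jax.jit(f)(jit_params)
jit_weight_after = jit_params['conv1']['weight']  # still all ones
```

Both calls return the same value; the difference only shows up in what the caller's params look like afterwards, which is exactly the kind of silent divergence the argument is about.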

Arguments in favor of regular dicts

  • @lucasb-eyer: If Flax tells me "here are these precious weights, please hold them for me and give them back to me later on, but DON'T TOUCH", it raises the question: why give them to me in the first place if I'm not supposed to do anything with them?

  • @avital: I also think it’d be better for Flax to return normal Python dicts, but still use FrozenDict within modules (via the mutable argument to apply).
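The "frozen inside, plain dicts at the boundary" idea can be sketched in pure Python. This is only an illustration under stated assumptions: freeze and unfreeze below are hypothetical helpers that use the standard library's MappingProxyType as a stand-in for FrozenDict (Flax ships its own freeze/unfreeze in flax.core.frozen_dict, which are not what is shown here).

```python
from types import MappingProxyType

def freeze(tree):
    # Hypothetical helper: recursively copy dicts into read-only views.
    if isinstance(tree, dict):
        return MappingProxyType({k: freeze(v) for k, v in tree.items()})
    return tree

def unfreeze(tree):
    # Hypothetical helper: recursively turn read-only views back into plain dicts.
    if isinstance(tree, MappingProxyType):
        return {k: unfreeze(v) for k, v in tree.items()}
    return tree

user_params = {'conv1': {'weight': [1.0, 2.0]}}

# Inside the library: work on a frozen view so accidental writes fail loudly.
internal = freeze(user_params)
try:
    internal['conv1']['weight'] = None
    write_blocked = False
except TypeError:
    write_blocked = True

# At the API boundary: hand back plain, mutable dicts.
returned = unfreeze(internal)
returned['conv1']['weight'] = [0.0, 0.0]
```

Because freeze copies the containers, the returned dicts do not alias the user's original ones, so the user can edit them freely while the library's internal code still gets the safety net.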

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 24 (15 by maintainers)

Top GitHub Comments

3 reactions
marcvanzee commented, Sep 6, 2021

Sorry for the delay – I was on parental leave.

@jheek could you tell us whether any progress has been made on merging the chex and flax dataclasses?

3 reactions
jheek commented, Apr 13, 2021

I think the Python saying "We're all consenting adults here" is pretty fitting.

Hidden state is notoriously hard to reason about, and I think all ML frameworks are currently struggling with it. See, for example, the common mistake of combining NumPy's default global RNG with multiprocessing (link). It's a hard issue to fix, though, because mutability "infects" all of your code and Python isn't a functional language.

That said, I don't think FrozenDict has proven to be a very effective safety tool for avoiding this kind of error. We should probably keep using it internally to avoid accidental reference sharing, but for users it seems too big a burden, and it doesn't prevent the more common issues of closing over mutable state (typically created by the user) or using things like np.random in a jitted function.
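The two "more common issues" can be reproduced in a few lines. This is a sketch with hypothetical names (config, scaled, noisy). Neither bug touches params, so freezing params would not catch them: jax.jit reads closed-over Python state and runs np.random only while tracing, and later calls with the same input shapes and dtypes reuse the cached trace.

```python
import jax
import jax.numpy as jnp
import numpy as np

config = {'scale': 2.0}  # mutable state created by the user

@jax.jit
def scaled(x):
    # config['scale'] is read once, during tracing, and baked in as a constant.
    return x * config['scale']

before = scaled(jnp.float32(1.0))
config['scale'] = 10.0            # mutating the closed-over dict...
after = scaled(jnp.float32(1.0))  # ...is ignored: the cached trace is reused

@jax.jit
def noisy(x):
    # np.random.rand() also runs only at trace time, so the "noise" is frozen.
    return x + np.random.rand()

first = noisy(jnp.float32(0.0))
second = noisy(jnp.float32(0.0))  # same value as first
```

The usual fixes are to pass such values in as explicit arguments and to use jax.random with an explicit key instead of NumPy's global RNG.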

I do think we should at least provide an easy way to clone a pytree if we allow it to contain mutable containers. Something like the following:

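import jax

def clone_pytree(xs):
  # Cloning is just an identity mapping: tree_map rebuilds every container
  # (dicts, lists, ...) and reuses the leaves, which is safe because JAX
  # arrays are immutable.
  return jax.tree_util.tree_map(lambda x: x, xs)

def some_nested_transformation(variables):
  # Proposed home for the helper (hypothetical): flax.traverse_util.clone_pytree
  my_copy = clone_pytree(variables)
  my_copy['batch_stats']['x'] += 2.
  return my_copy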
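A quick check of why an identity tree_map is enough, compared with a shallow dict copy. This is a sketch: clone_pytree is a local helper following the proposal above, and the variables layout is hypothetical.

```python
import jax
import jax.numpy as jnp

def clone_pytree(xs):
    # Identity tree_map: rebuilds every container, reuses the (immutable) leaves.
    return jax.tree_util.tree_map(lambda x: x, xs)

def make_variables():
    # Hypothetical variables layout, for illustration only.
    return {'params': {'w': jnp.ones((2,))},
            'batch_stats': {'x': jnp.zeros((2,))}}

# A shallow copy only duplicates the outer dict; the inner dicts are shared.
v1 = make_variables()
shallow = dict(v1)
shallow['batch_stats']['x'] += 2.0
leaked = float(v1['batch_stats']['x'][0])  # 2.0: the caller's dict changed

# The pytree clone duplicates the nested dicts too.
v2 = make_variables()
cloned = clone_pytree(v2)
cloned['batch_stats']['x'] += 2.0
untouched = float(v2['batch_stats']['x'][0])  # 0.0: the caller's dict is intact
shares_leaf = cloned['params']['w'] is v2['params']['w']  # leaves are reused
```

Reusing the leaves keeps the clone cheap; it is only safe because JAX arrays cannot be modified in place. A pytree holding mutable NumPy arrays would need a real copy of the leaves as well.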

Also, we want to merge the chex and Flax dataclass implementations. The most important difference is that chex dataclasses are mutable by default. I think we should keep the behaviour consistent, so ideally we would make these changes together.
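The mutability difference mentioned here can be sketched with the standard library's dataclasses module as a stand-in. This is not the chex or Flax API; FrozenStats and MutableStats are hypothetical names chosen to mirror the two defaults.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class FrozenStats:
    # Stand-in for an immutable dataclass: fields cannot be reassigned.
    mean: float
    count: int

@dataclasses.dataclass
class MutableStats:
    # Stand-in for a mutable-by-default dataclass (like chex's default).
    mean: float
    count: int

frozen = FrozenStats(mean=0.0, count=0)
try:
    frozen.count = 1
    frozen_rejected = False
except dataclasses.FrozenInstanceError:
    frozen_rejected = True

# With a frozen dataclass, updates go through an explicit copy.
updated = dataclasses.replace(frozen, count=1)

mutable = MutableStats(mean=0.0, count=0)
mutable.count = 1  # in-place mutation is allowed
```

The same trade-off as FrozenDict versus dict applies: the frozen version forces explicit copies, while the mutable version is more convenient but can be aliased and modified from elsewhere.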
