How to determine whether a Linen Module is still initializing?

See original GitHub issue

We currently check whether a module is initialized using self.has_variable, for instance in our definition of MultiHeadDotProductAttention:

...
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable('cache', 'cached_key')
...
if is_initialized:
  ...

I think this looks a bit strange to the average user. In the old API we were using self.is_initialized, and I think we should consider bringing this back in some form or another.

Below are some remarks made by different users offline:

From @levskaya: “The reason we did it this way is that it’s strictly more general if you imagine having separate variable collections that might have a more complicated initialization than just a single-shot ‘init’ function. That said, everyone seems to hate it, so I’m not sure the generality is worth not having an is_initializing module-global state variable…”

From @jheek: “You need to be careful with the definition of initializing. The simplest one is that we called Module.init and not Module.apply. But there are other cases: if I call the same Module in init for the second time, it’s already fully initialized, so should self.is_initializing return True or False? Something like this:”

def __call__(self, x):
  fc = Dense(5)
  x = fc(x)  # first call: this will init kernel, bias...
  x = fc(x)  # second call: already initialized
  return x

So to summarize: it seems useful to have a boolean on Module that says whether the module is initialized, but there are some tricky cases we should think about (see the example above).

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 14 (2 by maintainers)

Top GitHub Comments

2 reactions
levskaya commented, Jul 19, 2022

#2234 adds an explicit check method self.is_initializing for the use of nn.init(...)(...) or module.init(...).

1 reaction
cgarciae commented, May 11, 2022

@marksandler2 I was looking into this. BatchNorm detects initialization like this:

initializing = self.is_mutable_collection('params')

This results in batch statistics not being updated whenever params are mutable, because of:

if not initializing:
  ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
  ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

@jheek From an outside perspective, it’s not clear why making params mutable results in BatchNorm not training properly. If we want to avoid this issue, additional context information about whether we are inside init may be necessary.
