
FLIP: Make module instances semantically meaningful by not overriding `module.__new__`


Introduction

Currently, although Flax modules are defined by subclassing flax.nn.Module, those modules don’t behave the way normal Python objects do.

One of the large differences is that Flax Modules override __new__, meaning that module instances aren’t a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (nn.Dense(x, features=10)) actually does two things:

  1. Construct an instance of nn.Dense (using the undocumented API module.new_instance()).
  2. Call the apply method on that instance and return its result (see the sketch below).
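
Expressed as code, the two steps amount to roughly the following (a sketch based on the description above, not the actual Flax internals):

# Step 1: construct an nn.Dense instance (today via the undocumented new_instance()).
module = nn.Dense.new_instance()
# Step 2: call apply on that instance with the inputs and hyperparameters, returning its result.
y = module.apply(x, features=10)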

Some upsides of the current approach are:

  1. Modules are defined as a single function, as opposed to the style of other libraries such as Haiku, where you need to scroll up and down between __init__ and __call__ to fully understand what a module does.
  2. Calls to submodules are very concise, e.g. nn.Dense(x, features=10).

Some downsides of the current approach are:

  1. In order to reuse a module, you must use the module.shared() abstraction, which has a confusing mental model – what does module.shared() return? A module class? A module instance? Moreover, which arguments must be passed to module.shared() for the shared module to be usable? (Behind the scenes, shared is implemented on top of partial; a sketch of this pattern follows after this list.)
  2. You can’t instantiate a module directly outside of another module. This leads to surprising constructions like nn.Model(nn.Dense.partial(features=10), params) – why do we need to use partial to instantiate a Model? What type does the first argument to nn.Model have? Is it a module class? A module instance?
  3. In a few spots in flax/nn/base.py there is code that does “kwarg mangling”. Some of this code has had bugs before. It would be nice to reduce the need for kwarg mangling.
  4. In order to support multiple methods on a module, the module_method decorator turns methods other than apply into new Modules. This is surprising: for example, how would I do the equivalent of module.call(params, *args) for a method foo that isn’t apply? The answer is module.foo.call(params, *args), which is a pretty surprising mental model.
  5. Defining shared parameters or submodules that work across multiple methods requires non-traditional patterns and/or more complexity in Flax core (see the discussion on https://github.com/google/flax/issues/161).
  6. apply is a special-cased method on modules.
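
To make downside 1 concrete, here is a rough sketch of the sharing pattern under the current API (the TwoTower name is made up for illustration; exact call signatures may differ):

class TwoTower(nn.Module):
  def apply(self, x1, x2):
    # What does .shared() return here – a module class or an instance?
    # Either way, both calls below end up re-using the same parameters.
    shared_dense = nn.Dense.shared(features=10)
    return shared_dense(x1) + shared_dense(x2)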

Proposal

  1. No longer override __new__ in Modules
  2. Eliminate .partial()
  3. Potentially eliminate .shared() (though we may choose to keep it as a safeguard – see below)
  4. Split up current module’s apply methods into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a “vanilla Python” __call__ method (or actually, any name you want) that only takes in the module input(s)
  5. This may even allow module instances to directly refer to a read-only version of their parameters, e.g.:
class Foo(Module):
  def __call__(self, x):
    dense = nn.Dense(features=16)
    x = dense(x)
    # `dense.params` is defined here; maybe also `dense.params.kernel` and `dense.params.bias`
    return x

For example, a simple Dense layer may look like this:

# Imports shown for completeness (assuming jax.nn initializers); `Module` is the proposed new base class.
from dataclasses import dataclass
from typing import Callable

import jax.numpy as jnp
from jax.nn import initializers

@dataclass
class Dense(Module):
  features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    """Applies a linear transformation to the inputs along the last dimension."""
    kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    bias = self.param('bias', (self.features,), self.bias_init)
    return jnp.dot(x, kernel) + bias

Then, an MLP would look like this:

class MLP(Module):
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
    return x

I believe that this proposal keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental-model complexity. The mental model in Flax then reduces to the same one as Keras (other than the fact that parameters are immutable).

For example, in this case re-using a module is trivial – keep a reference to nn.Dense(features=16) and re-use that, as shown in the sketch below. (NOTE: We may choose to keep the safe-guarding behavior of .shared(), which makes it hard to re-use modules by accident via copy-pasted code. We can achieve that by having modules raise an error by default when __call__ is invoked a second time, unless .shared() was called on the module instance first.)
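
A minimal sketch of that re-use, assuming the dataclass-style Module described above (the SharedEncoder name is made up for illustration):

class SharedEncoder(Module):
  def __call__(self, x1, x2):
    dense = nn.Dense(features=16)   # a single instance holds a single set of parameters
    return dense(x1) + dense(x2)    # both calls re-use the same kernel and bias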

With this proposal, there’s also no need for module.partial – you can just use functools.partial(module.__call__) or functools.partial(module). (This is a bit different from current Flax because the return value of functools.partial isn’t itself a module but a plain function. Then again, maybe module.partial was always confusing – does it override kwargs for all module methods? Just apply?)
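
A short sketch of what replaces module.partial under this proposal (the helper names are made up for illustration):

import functools

# Pre-binding hyperparameters is just partial application of the constructor ...
make_dense16 = functools.partial(nn.Dense, features=16)
layer = make_dense16()                   # equivalent to nn.Dense(features=16)

# ... and pre-binding the call is partial application of the instance itself.
apply_layer = functools.partial(layer)   # same as functools.partial(layer.__call__)
# Note: apply_layer is a plain callable, not a Module.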

Possible transition plan

Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.

I propose adding, alongside every new module in flax.nn, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside Dense as shown above we would also have

def dense(x, features, kernel_init=initializers.lecun_normal(), bias_init=initializers.zeros):
  """DEPRECATED. Use the new Module API: http://link/to/upgrade/guide."""
  return Dense(features, kernel_init, bias_init)(x)

Then the first part of the upgrade process is mainly a search-and-replace of “Dense” -> “dense”, etc. In addition, some more manual changes will possibly be necessary for uses of .partial and .shared. Later, users can transition to the new API at a time they see fit.
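
As a concrete example of the mechanical step (a sketch; dense is the deprecated wrapper shown above):

# Before: current Flax style.
y = nn.Dense(x, features=10)
# After the search-and-replace, behavior unchanged:
y = nn.dense(x, features=10)
# Final form, once the call site is migrated to the new API:
y = nn.Dense(features=10)(x)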

Current State

@avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 18
  • Comments: 32 (16 by maintainers)

Top GitHub Comments

3 reactions
srush commented, Apr 20, 2020

Nice, I like this change. It is a good start.

However, if you are making such a breaking change, this feels too conservative.

Core Issues:

  • This function still violates Pythonic conventions. nn.Dense is seemingly making mutable changes to some internal state buffer that is invisible to the user and not transparent in the syntax. (I know this happens in TF, but flax should be better.)
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)

=> Does this mean?

  def __call__(self, x):
    x = nn.Dense(self, features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(self, features=16)(x)

(Or alternatively pytorch / sonnet 2 syntax which both do this better)

  • Params are still treated differently than Layers, and use string-based naming, which seems dangerous and tempting for abuse.
bias = self.param('bias', (self.features,), self.bias_init) 

=> ?

bias = nn.Param(self, (self.features,), self.bias_init)
3 reactions
lucasb-eyer commented, Apr 18, 2020

💯 to this change. It aligns the mental model with TF2’s tf.Module and PyTorch’s nn.Module a lot more, and both of these have converged to where they are now after many years of mistakes, so this is a good thing.

(NOTE: We may choose to keep the safe-guarding behavior of .shared(), which makes it hard to re-use modules by accident via copy-pasted code. We can achieve that by having modules raise an error by default when __call__ is invoked a second time, unless .shared() was called on the module instance first.)

Please don’t. Re-using an instance is the common, intuitive, friction-less way of sharing weights; this would just add annoying overhead for the sake of avoiding a mistake which, frankly, I have never encountered. An explicit :share method was how it was done in Torch7, and it was annoying and painful and does not exist anymore in PyTorch.

Regarding the __init__ vs __call__ separation, I don’t think it makes good code impossible, so if someone creates monster hydra code because of that, it’s probably the author’s fault, not the library’s. Using dataclass (or attr.s) for this is an interesting idea. However, usually what is done in __init__ is just normalizing parameters for convenience, for example allowing the filter size to be passed as (3,3) or as 3 and then turning 3 into (3,3) in __init__, so that __call__ is cleaner to read; with that in mind you can really skip reading __init__. I think this is a good thing.

Finally, I think you could have an even more convincing example with modules that have more than just the obvious __call__, like the VAE example here, which currently is not trivial to understand: I either have to do a lot of guesswork about Flax internals or go back and read the whole docs, whereas after your proposal (and in PyTorch) it can be much more straightforward.


