FLIP: Make module instances semantically meaningful by not overriding `module.__new__`
Introduction
Currently, while Flax modules are defined by subclassing `flax.nn.Module`, those modules don't behave the same way that normal Python objects behave. One of the large differences is that Flax Modules override `__new__`, meaning that module instances aren't a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (`nn.Dense(x, features=10)`) actually does two things:
- Construct an object of type `nn.Dense` (using the non-documented API `module.new_instance()`).
- Call the `apply` method on that instance and return it.
Some upsides of the current approach are:
- Modules are defined as a single function, as opposed to, e.g., the style of other libraries such as Haiku, where you need to scroll up and down between `__init__` and `__call__` to fully understand what a module does (see the sketch after this list).
- Calls to submodules are very concise, e.g. `nn.Dense(x, features=10)`.
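For concreteness, a module in today's API is a single `apply` function, and each submodule call constructs and applies the submodule in one expression. A rough sketch of the pre-FLIP `flax.nn` style, using only the calls mentioned in this issue:

```python
from flax import nn

class MLP(nn.Module):
  def apply(self, x):
    # Each call below both constructs the submodule and applies it to `x`.
    x = nn.Dense(x, features=16)
    x = nn.relu(x)
    return nn.Dense(x, features=10)
```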
Some downsides of the current approach are:
- In order to reuse a module, you must use the `module.shared()` abstraction, which has a confusing mental model – what does `module.shared()` return? A module class? A module instance? Moreover, which arguments must be passed into `module.shared()` in order for the shared module to be usable? (Behind the scenes, `shared` is implemented on top of `partial`.)
- You can't instantiate a module directly, outside of another module. This leads to surprising things like `nn.Model(nn.Dense.partial(features=10), params)` – why do we need to use `partial` to instantiate a Model? What type does the first argument to `nn.Model` have? Is it a module class? A module instance? (See the fragments sketched after this list.)
- In a few spots in `flax/nn/base.py` there is code that does "kwarg mangling". Some of this code had bugs before. It would be nice to reduce the need for kwarg mangling.
- In order to support multiple methods on a module, the `module_method` decorator turns methods that aren't `apply` into new Modules. This is surprising: for example, how would I do the equivalent of `module.call(params, *args)` but for a method `foo` that's not `apply`? That would be `module.foo.call(params, *args)`. That's a pretty surprising mental model.
- Wanting to define shared parameters or submodules that work across multiple methods requires either using non-traditional patterns and/or more complexity in Flax core (see discussion on https://github.com/google/flax/issues/161), because `apply` was a special-cased method on modules.
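To make the confusion concrete, these are the kinds of fragments the points above refer to, reproduced from the identifiers used in this issue (illustrative snippets of the current API, not a complete runnable example):

```python
# Sharing today: what does `shared()` return, and which kwargs must it receive?
dense = nn.Dense.shared(features=16)
h1 = dense(x1)
h2 = dense(x2)  # reuses the same parameters

# Instantiating a Model today requires `partial`, even though nothing has been applied yet:
model = nn.Model(nn.Dense.partial(features=10), params)

# Calling a non-`apply` method `foo` with explicit params:
out = module.foo.call(params, *args)
```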
Proposal
- No longer override `__new__` in Modules.
- Eliminate `.partial()`.
- Potentially eliminate `.shared()` (though we may choose to keep it as a safeguard – see below).
- Split up the current module's `apply` method into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a "vanilla Python" `__call__` method (or actually, any name you want) that only takes in the module input(s).
- This may even allow module instances to directly refer to a read-only version of their parameters, e.g.:
```python
class Foo(Module):
  def __call__(self, x):
    dense = nn.Dense(features=16)
    x = dense(x)
    # `dense.params` is defined here; maybe also `dense.params.kernel` and `dense.params.bias`
    return x
```
For example, a simple Dense layer may look like this:
```python
@dataclass
class Dense(Module):
  features: int
  kernel_init: Callable = initializers.lecun_normal()
  bias_init: Callable = initializers.zeros

  def __call__(self, x):
    """Applies a linear transformation to the inputs along the last dimension."""
    kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
    bias = self.param('bias', (self.features,), self.bias_init)
    return jnp.dot(x, kernel) + bias
```
Then, an MLP would look like this:
```python
class MLP(Module):
  def __call__(self, x):
    x = nn.Dense(features=16)(x)
    x = nn.relu(x)
    x = nn.Dense(features=16)(x)
    return x
```
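Because methods become plain Python under this proposal, modules that expose more than one method no longer need `module_method`. A hedged sketch (the `AutoEncoder`, `encode`, and `decode` names are illustrative, not part of the FLIP):

```python
@dataclass
class AutoEncoder(Module):
  latents: int
  features: int

  def encode(self, x):
    return nn.Dense(features=self.latents)(x)

  def decode(self, z):
    return nn.Dense(features=self.features)(z)

  def __call__(self, x):
    # Both methods are ordinary Python methods; no special-casing of `apply`.
    return self.decode(self.encode(x))
```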
I believe that this proposal keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental-model complexity. The mental model in Flax then reduces to the same one as Keras (other than the fact that parameters are immutable).
For example, in this case re-using a module is trivial – keep a reference to `nn.Dense(features=16)` and re-use that. (NOTE: We may choose to keep the safe-guarding behavior of `.shared()` that makes it hard to re-use modules by accident via copy-and-pasted code. We can achieve that by having modules default to raising an error when `__call__` is invoked a second time, unless `.shared()` was called on the module instance first.)
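A minimal sketch of what such re-use could look like under the proposal (assuming the `Dense`-style modules above; the `TiedPair` name is invented here):

```python
class TiedPair(Module):
  def __call__(self, x):
    dense = nn.Dense(features=16)  # constructed once, so both calls share parameters
    x = dense(x)
    # Under the optional safeguard discussed above, this second call would
    # raise unless `dense.shared()` had been called first.
    return dense(x)
```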
With this proposal, there's also no need for `module.partial` – you can just use `functools.partial(module.__call__)` or `functools.partial(module)`. (Though this is a bit different than in current Flax, because the return value of `functools.partial` isn't itself a module, rather it's a function. But maybe it was always confusing to understand `module.partial` – does it override kwargs for all module methods? Just `apply`?)
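For illustration, the replacement might look like this (a sketch under the proposed API; the `apply_dense` and `Dense16` names are invented here):

```python
import functools

dense = Dense(features=16)
apply_dense = functools.partial(dense)   # or functools.partial(dense.__call__)
# `apply_dense` is a plain function, not a Module, as noted above.

# Hyperparameters can likewise be pre-bound at construction time:
Dense16 = functools.partial(Dense, features=16)
layer = Dense16(bias_init=initializers.zeros)
```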
Possible transition plan
Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.
I propose adding, alongside every new module in `flax.nn`, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside `Dense` as shown above we would also have:
```python
def dense(x, features, kernel_init=initializers.lecun_normal(), bias_init=initializers.zeros):
  """DEPRECATED. Use the new Module API: http://link/to/upgrade/guide."""
  return Dense(features, kernel_init, bias_init)(x)
```
Then the first part of the upgrade process is mainly a search-and-replace of "Dense" -> "dense", etc. In addition, some more manual changes will possibly be necessary for uses of `.partial` and `.shared`. Later, users can transition to the new API at a time they see fit.
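Roughly, the two-step migration for a call site would look like this (illustrative; the exact call sites depend on the user's code):

```python
# Step 1: mechanical rename, same behavior as today's nn.Dense(x, features=10):
y = dense(x, features=10)

# Step 2 (later, at the user's convenience): move to the new Module API:
y = Dense(features=10)(x)
```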
Current State
@avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.
Top GitHub Comments
Nice, I like this change. It is a good start. However, if you are making such a breaking change, this feels too conservative.

Core issues: `nn.Dense` is seemingly making mutable changes to some internal state buffer that is invisible to the user and not transparent in the syntax. (I know this happens in TF, but Flax should be better.) => Does this mean? (Or alternatively PyTorch / Sonnet 2 syntax, which both do this better) => ?
💯 to this change. It aligns the mental model with TF2's `tf.Module` and PyTorch's `nn.Module` a lot more, and both of these have converged to where they are now after many years of mistakes, so this is a good thing.

Please don't [keep the `.shared()` safeguard]. Re-using an instance is the common, intuitive, friction-less way of sharing weights; this would just add annoying overhead for the sake of avoiding a mistake which, frankly, I have never encountered. An explicit `:share` method was how it was done in Torch7, and it was annoying and painful and does not exist anymore in PyTorch.

Regarding the `__init__` vs `__call__` separation, I don't think that it makes good code impossible, so if someone creates a monster hydra of code because of that, it's probably the author's fault, not the library's. Using dataclass (or `attr.s`) for this is an interesting idea. However, usually what is done in `__init__` is just normalizing convenience parameters, for example allowing a filter size to be passed as `(3, 3)` or as `3`, and then turning `3` into `(3, 3)` in `__init__`, such that `__call__` is cleaner to read, and really you can skip reading `__init__` with that in mind. I think this is a good thing.

Finally, I think you can have an even more convincing example for modules which have more than just the obvious `__call__`, like the VAE example here, which currently is not trivial to understand: I either have to do a lot of guess-work about Flax internals, or go back and read the whole docs. Whereas after your proposal (and in PyTorch) it can be much more straightforward.
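For illustration of the `__init__` normalization pattern described in the comment above (generic Python, not tied to the FLIP's dataclass approach; the `Conv` name and `kernel_size` argument are invented here):

```python
class Conv:
  def __init__(self, features, kernel_size=3):
    # Normalize the convenience form `3` to the canonical `(3, 3)` here,
    # so that __call__ can assume a tuple and stays easy to read.
    if isinstance(kernel_size, int):
      kernel_size = (kernel_size, kernel_size)
    self.features = features
    self.kernel_size = kernel_size
```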