FLIP: rm 'shape' from Module.param call signature
Goals
In our current implementation, initializer methods (which get passed to `Module.param`) are required to have the following signature:
```python
def initializer(rng_key, shape):
  # ...
  return initialized_parameters
```
and in `Module.param` we assert that `initialized_parameters.shape == shape`.
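For reference, the stock initializers in `jax.nn.initializers` (which Flax reuses) follow this `(key, shape)` convention:

```python
import jax
from jax.nn import initializers

key = jax.random.PRNGKey(0)
kernel_init = initializers.lecun_normal()
kernel = kernel_init(key, (3, 4))  # the returned array matches the shape argument
assert kernel.shape == (3, 4)
```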
Sometimes an initializer needs more (or less) information than the shape of its output, and at the moment this is achieved by writing the initializer function within a `Module` definition so that it can close over any other data it requires. For example, weightnorm initialization, which is data-dependent, can be implemented as follows; note the necessity of the dummy shape arguments:
```python
from functools import partial

import jax.numpy as np
from flax import nn
from jax import lax


class Conv2D(nn.Module):
  def apply(self, inputs, features, kernel_size):
    strides = (inputs.ndim - 2) * (1,)
    conv = partial(
        lax.conv_general_dilated, window_strides=strides, padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, features)

    def initializer(key, shape):
      # A weightnorm initializer generating a (direction, scale, bias) tuple.
      # Note that the shape argument is not used.
      direction = nn.initializers.normal()(key, kernel_shape)
      unnormed_out = conv(inputs, _l2_normalize(direction))
      mean = np.mean(unnormed_out, (0, 1, 2))
      std = np.std(unnormed_out, (0, 1, 2))
      return dict(direction=direction, scale=1 / std, bias=-mean / std)

    # We feed in None as a dummy shape argument to self.param. Currently
    # Module.param assumes that the initializer takes in a shape argument;
    # None acts as a flag to avoid shape checking.
    params = self.param('weightnorm_params', None, initializer)
    direction, scale, bias = [params[k] for k in ('direction', 'scale', 'bias')]
    return conv(inputs, _make_kernel(direction, scale)) + bias
```
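The helpers `_l2_normalize` and `_make_kernel` are left undefined above; a plausible sketch of what they might look like (these definitions are assumptions, not part of the original issue):

```python
def _l2_normalize(x, eps=1e-8):
  # Normalize the direction tensor over all but the output-feature axis.
  norm = np.sqrt(np.sum(np.square(x), axis=(0, 1, 2), keepdims=True))
  return x / (norm + eps)

def _make_kernel(direction, scale):
  # Weight-normalized kernel: per-output-feature scale applied to the
  # unit-norm direction.
  return scale * _l2_normalize(direction)
```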
This situation isn’t terrible, but it does highlight the fact that the assumption that initializers depend on parameter shape and nothing else is a bit arbitrary.
A more flexible API, with `initializer` requiring only a JAX PRNG key, would mean more consistent implementations of different types of initializers. It might also help to emphasize what Flax's role actually is here: to handle splitting and passing of PRNG keys to initializers, and to set up the parameters data structure (a nested dictionary).
Proposal
We propose to change the call signature of `Module.param` from

```python
param(self, name: str, shape: Shape, initializer: Callable[[Key, Shape], Array])
```

to

```python
param(self, name: str, initializer: Callable[[Key], Array])
```
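To make the division of labour concrete, here is a rough sketch of what `param` might do under the proposal; the attribute and helper names (`self._params`, `self._next_param_rng`) are hypothetical and are not Flax internals:

```python
def param(self, name, initializer):
  # Flax still owns PRNG-key handling and the parameter dictionary; it simply
  # no longer asserts anything about the shape of what the initializer returns.
  if name not in self._params:          # hypothetical parameter store
    key = self._next_param_rng()        # hypothetical key-splitting helper
    self._params[name] = initializer(key)
  return self._params[name]
```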
This change would lead to a slight simplification of the weightnorm example above, since the dummy shape arguments could be removed.
For existing Modules for which the initializer is a straightforward function of the parameter shape and no other data is required, we can alter the currying of the initializer definitions so that lines like

```python
kernel = self.param('kernel', kernel_shape, kernel_init)
```

can be replaced by

```python
kernel = self.param('kernel', kernel_init(kernel_shape))
```
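For illustration, a re-curried initializer would bind the shape first and return a closure that only needs a PRNG key. The name `normal_init` below is hypothetical, not an existing Flax or JAX function:

```python
import jax

def normal_init(shape, stddev=1e-2):
  # Binds the shape up front and returns a Callable[[Key], Array].
  def init(key):
    return stddev * jax.random.normal(key, shape)
  return init

# Under the proposed API this would be used as:
#   kernel = self.param('kernel', normal_init(kernel_shape))
kernel = normal_init((3, 3, 8, 16))(jax.random.PRNGKey(0))
assert kernel.shape == (3, 3, 8, 16)
```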
There may be downsides to this approach which I, being a relative Flax noob, am unaware of. One is that we'd lose the shape checking in `Module.param`, but that seems like the kind of check which should be part of a test anyway.
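As a sketch of what that might look like, the shape check on the weightnorm parameters could live in a test; the parameter layout below just mirrors the dict returned by the initializer above and is not a specific Flax structure:

```python
def check_weightnorm_param_shapes(params, kernel_shape):
  # The weightnorm initializer stores a dict under 'weightnorm_params'.
  wn = params['weightnorm_params']
  assert wn['direction'].shape == kernel_shape
  assert wn['scale'].shape == (kernel_shape[-1],)
  assert wn['bias'].shape == (kernel_shape[-1],)
```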
Alternatives
I think the obvious alternative is to simply keep the current API. The change proposed above is relatively minor, but it would still likely require a number of users to make changes to their own code.
Top GitHub Comments
I agree the shape argument is arbitrary. It also implies that we can only return a single tensor, which is also arbitrary.
Note that Flax is using the initializers from `jax.nn.initializers`, so that creates some difficulty when we want to reuse those.

An alternative that is less invasive is the following. Change the current `param` API:

to:
This has been resolved (with an improved shape check) in #530
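For context on where the API eventually landed, Flax's Linen `Module.param` forwards extra positional arguments to the initializer, so a shape can be passed through explicitly. A minimal sketch of that pattern (not necessarily the exact change made in #530):

```python
import jax
import jax.numpy as jnp
import flax.linen as linen


class Dense(linen.Module):
  features: int

  @linen.compact
  def __call__(self, x):
    # param forwards its extra positional arguments to the initializer, so the
    # shape is passed explicitly rather than assumed by the framework.
    kernel = self.param('kernel',
                        linen.initializers.lecun_normal(),
                        (x.shape[-1], self.features))
    return x @ kernel


y, variables = Dense(features=4).init_with_output(jax.random.PRNGKey(0),
                                                  jnp.ones((2, 3)))
```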