
FLIP: rm 'shape' from Module.param call signature


Goals

In our current implementation, initializer methods (which get passed to Module.param) are required to have the following signature:

def initializer(rng_key, shape):
  # ...
  return initialized_parameters

and in Module.param we assert that initialized_parameters.shape == shape.
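
For concreteness, here is a minimal sketch of an initializer conforming to this contract, together with the shape check that Module.param performs (zeros_initializer is hypothetical, not a Flax initializer):

import jax
import jax.numpy as jnp

def zeros_initializer(rng_key, shape):
  # Conforms to the current (key, shape) contract; this one ignores the key.
  return jnp.zeros(shape)

key = jax.random.PRNGKey(0)
param = zeros_initializer(key, (3, 3, 16, 32))
assert param.shape == (3, 3, 16, 32)  # the check Module.param performs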

Sometimes an initializer needs more (or less) information than the shape of its output. At the moment this is handled by writing the initializer function inside a Module definition so that it can close over whatever other data it requires. For example, weightnorm initialization, which is data-dependent, can be implemented as follows; note the dummy shape arguments it forces on us:

from functools import partial

import jax.numpy as np
from jax import lax
from flax import nn  # the pre-Linen Flax API in use at the time of this issue

# _l2_normalize and _make_kernel are small module-level helpers, omitted here.

class Conv2D(nn.Module):
  def apply(self, inputs, features, kernel_size):
    strides = (inputs.ndim - 2) * (1,)

    conv = partial(
        lax.conv_general_dilated, window_strides=strides, padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

    in_features = inputs.shape[-1]
    kernel_shape = kernel_size + (in_features, features)

    def initializer(key, shape):
      # A weightnorm initializer generating a (direction, scale, bias) tuple.
      # Note that the shape argument is not used.
      direction = nn.initializers.normal()(key, kernel_shape)
      unnormed_out = conv(inputs, _l2_normalize(direction))
      mean = np.mean(unnormed_out, (0, 1, 2))
      std = np.std(unnormed_out, (0, 1, 2))
      return dict(direction=direction, scale=1 / std, bias=-mean / std)

    # We feed in None as a dummy shape argument to self.param. Currently
    # Module.param assumes that the initializer takes a shape argument;
    # None acts as a flag to avoid shape checking.
    params = self.param('weightnorm_params', None, initializer)
    direction, scale, bias = [params[k] for k in ('direction', 'scale', 'bias')]
    return conv(inputs, _make_kernel(direction, scale)) + bias

This situation isn’t terrible, but it highlights how arbitrary the assumption is that an initializer depends on the parameter shape and nothing else.

A more flexible API, in which an initializer requires only a JAX PRNG key, would make implementations of different types of initializers more consistent. It might also correctly emphasize Flax’s role here: splitting and passing PRNG keys to initializers, and setting up the parameters data structure (a nested dictionary).
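
As a rough sketch of that role (illustration only, not Flax internals): split the top-level key once per parameter and collect each initializer’s output into a dictionary:

from jax import random

def init_params(rng_key, initializers):
  # initializers: dict mapping parameter name -> (key -> array) initializer.
  params = {}
  keys = random.split(rng_key, len(initializers))
  for key, (name, init_fn) in zip(keys, initializers.items()):
    params[name] = init_fn(key)
  return params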

Proposal

We propose to change the call signature of Module.param from

param(self, name: str, shape: Shape, initializer: Callable[[Key, Shape], Array]):

to

param(self, name: str, initializer: Callable[[Key], Array]):

This change would lead to a slight simplification of the weightnorm example above, since the dummy shape arguments could be removed.
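
Concretely, the inner initializer from the weightnorm example would become something like this sketch (reusing the closed-over conv, inputs, kernel_shape, and helpers from above):

def initializer(key):
  # No dummy shape argument needed under the proposed API.
  direction = nn.initializers.normal()(key, kernel_shape)
  unnormed_out = conv(inputs, _l2_normalize(direction))
  mean = np.mean(unnormed_out, (0, 1, 2))
  std = np.std(unnormed_out, (0, 1, 2))
  return dict(direction=direction, scale=1 / std, bias=-mean / std)

params = self.param('weightnorm_params', initializer)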

For existing Modules for which the initializer is a straightforward function of the parameter shape and no other data is required, we can alter the currying of the initializer definitions so that lines like

kernel = self.param('kernel', kernel_shape, kernel_init)

can be replaced by

kernel = self.param('kernel', kernel_init(kernel_shape))
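
In curried form, such an initializer factory might look like this sketch (kernel_init here is hypothetical, not the existing jax.nn.initializers API):

import jax

def kernel_init(shape, stddev=1e-2):
  # Bind the shape up front and return a key-only initializer.
  def init(key):
    return stddev * jax.random.normal(key, shape)
  return init

init_fn = kernel_init((3, 3, 16, 32))
kernel = init_fn(jax.random.PRNGKey(0))
assert kernel.shape == (3, 3, 16, 32)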

There may be downsides to this approach which I, being a relative Flax noob, am unaware of. One is that we’d lose the shape checking in Module.param, but that seems like the kind of check that should be part of a test anyway.
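
Such a check could move into a unit test along these lines (a sketch; initializer and kernel_shape refer to the weightnorm example above, with features the number of output channels):

import jax

def test_weightnorm_initializer_shapes():
  params = initializer(jax.random.PRNGKey(0))
  assert params['direction'].shape == kernel_shape
  assert params['scale'].shape == (features,)  # one scale per output channel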

Alternatives

I think the obvious alternative is to simply keep the current API. The change proposed above is relatively minor, but it would still likely require a number of users to update their own code.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 8 (4 by maintainers)

Top GitHub Comments

1 reaction
jheek commented, Apr 1, 2020

I agree the shape argument is arbitrary. It also implies that we can only return a single tensor, which is also arbitrary.

Note that Flax uses the initializers from jax.nn.initializers, which creates some difficulty when we want to reuse them.
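
Since the jax.nn.initializers factories return functions of (key, shape[, dtype]), reusing them under a key-only signature would need a small adapter, e.g.:

import jax
from jax.nn.initializers import lecun_normal

kernel_shape = (3, 3, 16, 32)
init_fn = lambda key: lecun_normal()(key, kernel_shape)  # bind the shape
kernel = init_fn(jax.random.PRNGKey(0))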

An alternative that is less invasive is the following:

Change the current param API from:

def param(self, shape, init_fn):
  ...
  init_value = init_fn(rng, shape)

to:

def param(self, init_fn, shape=None):
  ...
  if shape is None:
    init_value = init_fn(rng)
  else:
    init_value = init_fn(rng, shape)
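
Stripped of the Module machinery, the dispatch would behave like this standalone sketch, supporting both styles:

import jax
import jax.numpy as jnp

def param(rng, init_fn, shape=None):
  # Standalone sketch of the optional-shape dispatch above.
  if shape is None:
    return init_fn(rng)
  value = init_fn(rng, shape)
  assert value.shape == shape  # keep today's shape check in this branch
  return value

key = jax.random.PRNGKey(0)
a = param(key, lambda k, s: jnp.zeros(s), shape=(4, 4))  # (key, shape) style
b = param(key, lambda k: jax.random.normal(k, (4, 4)))   # key-only style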

0 reactions
avital commented, Dec 12, 2020

This has been resolved (with an improved shape check) in #530
