
Variable batch sizes in flax.nn.DenseGeneral

Simple use case: I have input data whose size is not evenly divisible by the batch size, so the last batch is smaller than the rest.

In flax.nn.Dense, the kernel shape is (input.shape[-1], features), which means every batch item shares the same kernel, so we can pass arbitrarily sized batches to the same model and get back outputs of the corresponding batch size.
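
For instance (a minimal sketch using the same pre-linen flax.nn API as the example below; the shapes are illustrative):

import flax
import jax
import numpy as np

# The kernel depends only on the trailing feature dimension, so the
# same parameters apply to inputs with any leading batch size.
model_def = flax.nn.Dense.partial(features=1)
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(2, 57)])
print(params["kernel"].shape)  # (57, 1)

model = flax.nn.Model(model_def, params)
print(model(np.random.rand(3, 57)).shape)  # (3, 1)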

By contrast, flax.nn.DenseGeneral has a kernel shape of batch_shape + kernel_shape, which locks us into a fixed batch size (see the self-contained example below). As a result we also train a factor of batch_shape more parameters, since each item in a batch gets its own copy of the kernel.

This second point might not matter, but the behavior is different from flax.nn.Dense; I had previously thought flax.nn.Dense was implemented using flax.nn.DenseGeneral.

The first solution that comes to mind is to ignore the actual value of the batch dimensions when we initialize the kernel/bias parameters, much as flax.nn.Dense does. However, we would need to be careful if the user has no batch dimensions.
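
A rough sketch of that idea using jax.lax.dot_general directly (the helper name and shapes here are hypothetical, not Flax API):

import jax
import jax.numpy as jnp

def dense_general_shared(x, axis, kernel):
    # Contract over `axis` only. The kernel carries no batch dims, so
    # every remaining dimension of x is ML-style batching that reuses
    # the same kernel.
    contract = tuple(a % x.ndim for a in axis)
    rhs_contract = tuple(range(len(contract)))
    return jax.lax.dot_general(
        x, kernel, dimension_numbers=((contract, rhs_contract), ((), ())))

kernel = jax.random.normal(jax.random.PRNGKey(0), (1, 57, 1))
print(dense_general_shared(jnp.ones((2, 1, 57)), (1, 2), kernel).shape)  # (2, 1)
print(dense_general_shared(jnp.ones((3, 1, 57)), (1, 2), kernel).shape)  # (3, 1)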

The above might be intended behavior, but just wanted to note my surprise / desire for something a little different.

import flax
import jax
import numpy as np

model_def = flax.nn.DenseGeneral.partial(
    features=1,
    axis=(1, 2),
    batch_dims=0,
    kernel_init=flax.nn.initializers.kaiming_normal()
)

_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(2, 1, 57)])

print(params["bias"].shape, params["kernel"].shape)
# (2, 1) (2, 1, 57, 1)

X = np.random.rand(3, 1, 57)

# Bind the initialized parameters to the module before applying it.
model = flax.nn.Model(model_def, params)

model(X)
# ValueError: Existing shape (2, 1, 57, 1) differs from requested shape (3, 1, 57, 1)
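
For reference, dropping batch_dims sidesteps the problem: the kernel then has no leading batch axis, so the same parameters accept any batch size (a sketch under the same API, continuing the snippet above):

model_def = flax.nn.DenseGeneral.partial(features=1, axis=(1, 2))
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(2, 1, 57)])
print(params["kernel"].shape)  # (1, 57, 1) -- no leading batch dimension

model = flax.nn.Model(model_def, params)
print(model(np.random.rand(3, 1, 57)).shape)  # (3, 1)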

Top GitHub Comments

1 reaction
jheek commented, May 14, 2020

Yes, it is pretty much a wrapper around dot_general. The complexity arises from kernel initialization. To be clear: a batch axis in ML is an axis that appears only in the inputs, whereas in a batched matmul the batch dim appears in both the input and the kernel. So axis dims are contracted over, batch_dims appear in both the input and the kernel, and any remaining dimension that appears in the input is multiplied with the same kernel (ML-style batching).
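
A small jax.lax.dot_general example contrasting the three kinds of axes (the shapes are illustrative, not from the issue):

import jax
import jax.numpy as jnp

x = jnp.ones((4, 3, 5))  # (batch, ml_batch, contract)
w = jnp.ones((4, 5, 2))  # (batch, contract, features)

# Dim 0 is a true batch dim (present in both operands), dim 2 of x and
# dim 1 of w are contracted, and dim 1 of x is ML-style batching: it
# survives into the output and reuses the same kernel slice.
y = jax.lax.dot_general(x, w, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
print(y.shape)  # (4, 3, 2)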

0 reactions
danielsuo commented, May 14, 2020

Yep, makes sense! Once I realized DenseGeneral is essentially just a wrapper around dot_general, there was no more confusion. I’d still note my original surprise, but I’m not sure how big a deal that is vs. making a breaking change.

Anyway, we kept a version of my example implementation in timecast here.

I’ll close this issue and the associated PR.
