Variable batch sizes in flax.nn.DenseGeneral
Simple use case: I have input data whose size is not evenly divisible by the batch size.
In flax.nn.Dense, the kernel shape is (input.shape[-1], features), which means every batch item shares the same kernel and we can pass arbitrarily-sized batches to the same model and get back correspondingly-shaped outputs.
By contrast, flax.nn.DenseGeneral has a kernel of shape batch_shape + kernel_shape, which locks us into a batch size (see the self-contained example below). We also train a factor of batch_shape more parameters as a result, since each item in a batch has its own copy of the kernel.
This second point might not matter, but the behavior is different from flax.nn.Dense; I had previously thought flax.nn.Dense was implemented using flax.nn.DenseGeneral.
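For comparison, here is a minimal sketch with flax.nn.Dense (the same pre-Linen flax.nn API as the reproduction below; the dense_def / dense_model names are just for illustration), showing that its kernel carries no batch dimension, so a different batch size applies cleanly:

import flax
import jax
import numpy as np
dense_def = flax.nn.Dense.partial(features=1)
_, dense_params = dense_def.init_by_shape(jax.random.PRNGKey(0), [(2, 57)])
print(dense_params["kernel"].shape)
# (57, 1) -- no batch dimension baked into the parameters
dense_model = flax.nn.Model(dense_def, dense_params)
print(dense_model(np.random.rand(3, 57)).shape)
# (3, 1) -- a batch of 3 works even though we initialized with a batch of 2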
The first solution that comes to mind is to ignore the values in the batch dimensions when we initialize the kernel/bias parameters, much as flax.nn.Dense does. However, we would need to be careful when the user has no batch dimensions.
The above might be intended behavior, but just wanted to note my surprise / desire for something a little different.
import flax
import jax
import numpy as np

model_def = flax.nn.DenseGeneral.partial(
    features=1,
    axis=(1, 2),
    batch_dims=0,
    kernel_init=flax.nn.initializers.kaiming_normal(),
)
_, params = model_def.init_by_shape(jax.random.PRNGKey(0), [(2, 1, 57)])
print(params["bias"].shape, params["kernel"].shape)
# (2, 1) (2, 1, 57, 1)

# wrap the initialized parameters so the model can be applied to new inputs
model = flax.nn.Model(model_def, params)
X = np.random.rand(3, 1, 57)  # batch of 3 instead of the 2 used at init
model(X)
# ValueError: Existing shape (2, 1, 57, 1) differs from requested shape (3, 1, 57, 1)
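To illustrate the first solution above (derive the kernel shape only from the contracted axes and the features, ignoring the batch dimensions), here is a minimal sketch using raw jax.lax.dot_general rather than flax; dense_general_sketch and its arguments are hypothetical names, not flax API:

import jax
import numpy as np

def dense_general_sketch(x, kernel, bias, axis=(1, 2)):
    # Contract `axis` of x against the leading axes of `kernel`; all remaining
    # input axes (including the batch axis) are left alone, so every batch
    # item reuses the same kernel.
    contract = (tuple(axis), tuple(range(len(axis))))
    return jax.lax.dot_general(x, kernel, dimension_numbers=(contract, ((), ()))) + bias

kernel = np.random.rand(1, 57, 1)  # (contracted axis 1, contracted axis 2, features)
bias = np.zeros((1,))
print(dense_general_sketch(np.random.rand(2, 1, 57), kernel, bias).shape)
# (2, 1)
print(dense_general_sketch(np.random.rand(3, 1, 57), kernel, bias).shape)
# (3, 1) -- different batch size, same parameters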
Top GitHub Comments
Yes, it is pretty much a wrapper around dot_general. The complexity arises due to kernel initialization. To be clear: a batch axis in the ML sense is an axis that only appears in the inputs, whereas in a batched matmul the batch dim appears in both the input and the kernel. So axis dims are contracted over, while batch_dims appear in both the input and the kernel; any remaining dimensions that appear in the input are multiplied with the same kernel (ML-style batching).
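A short illustration of that distinction with raw jax.lax.dot_general (shapes chosen arbitrarily):

import jax
import numpy as np

x = np.random.rand(2, 4, 5)

# ML-style batching: only contract the feature axis; the kernel has no batch
# axis, so the leading axes of x are free and any batch size is fine.
kernel = np.random.rand(5, 3)
y = jax.lax.dot_general(x, kernel, dimension_numbers=(((2,), (0,)), ((), ())))
print(y.shape)
# (2, 4, 3)

# dot_general-style batch dims: the batch axis appears in *both* operands,
# so the kernel's leading axis must match the batch size exactly.
batched_kernel = np.random.rand(2, 5, 3)
y = jax.lax.dot_general(x, batched_kernel, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
print(y.shape)
# (2, 4, 3)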
Yep, makes sense! Once I realized DenseGeneral is essentially just a wrapper around dot_general, there was no more confusion. I'd still note my original surprise, but I'm not sure how big a deal that is vs. making a breaking change.

Anyway, we kept a version of my example implementation in timecast here. I'll close this issue and the associated PR.