Inconsistent batching behavior between MLPs and convnets
The following code snippets show that stax MLPs can be defined w.r.t. unbatched examples (in_shape = (1,)), while convnets seem to require a batch dimension (though it can be -1). Is this intended behavior?
# Works
from jax.experimental import stax
from jax.experimental.stax import (Conv, Dense, Relu, MaxPool,
                                   Flatten, LogSoftmax)

net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)
# Initialize parameters, not committing to a batch shape
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(in_shape)
# Works
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
in_shape = (1,)
out_shape, net_params = net_init(in_shape)
# Doesn't work
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)
# Initialize parameters, not committing to a batch shape
in_shape = (28, 28, 1)
out_shape, net_params = net_init(in_shape)
The last one returns the following error:

IndexError                                Traceback (most recent call last)
<ipython-input-4-d145a7688535> in <module>()
      9 # Initialize parameters, not committing to a batch shape
     10 in_shape = (28, 28, 1)
---> 11 out_shape, net_params = net_init(in_shape)

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
    269   params = []
    270   for init_fun in init_funs:
--> 271     input_shape, param = init_fun(input_shape)
    272     params.append(param)
    273   return input_shape, params

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
    109   kernel_shape = [out_chan if c == 'O' else
    110                   input_shape[lhs_spec.index('C')] if c == 'I' else
--> 111                   next(filter_shape_iter) for c in rhs_spec]
    112   output_shape = lax.conv_general_shape_tuple(
    113       input_shape, kernel_shape, strides, padding, dimension_numbers)

IndexError: tuple index out of range
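In the meantime, a workaround is to initialize with a placeholder batch dimension (as in the first snippet) and add a singleton batch axis to each example at apply time. A minimal sketch, reusing the convnet and the net_init / net_apply pair defined above; the zeros input is just an illustrative placeholder:

import jax.numpy as jnp

# Initialize with a placeholder batch dimension (only the rank matters here).
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(in_shape)

x = jnp.zeros((28, 28, 1))             # a single unbatched example
y = net_apply(net_params, x[None])[0]  # add a batch axis, then strip it; y has shape (10,)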
Agreed, there should definitely be the flexibility to define batched nets (for models whose forward pass requires minibatches). But given the behavior of the MLP, a user can easily expect that Conv would automatically prepend a singleton batch dimension under the hood when ndims == 3. It is confusing that some primitives assume batching while others do not. A sketch of what such auto-batching could look like is below.
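For concreteness, here is a minimal sketch of a hypothetical user-side wrapper (auto_batch is not part of stax), assuming the init_fun(input_shape) / apply_fun(params, inputs) signatures used in the snippets above:

import jax.numpy as jnp

def auto_batch(layer, unbatched_ndim=3):
    """Hypothetical wrapper: accept unbatched inputs by prepending a
    singleton batch dimension before init/apply and stripping it after."""
    init_fun, apply_fun = layer
    def init(input_shape):
        if len(input_shape) == unbatched_ndim:
            out_shape, params = init_fun((1,) + tuple(input_shape))
            return out_shape[1:], params  # report an unbatched output shape
        return init_fun(input_shape)
    def apply(params, inputs):
        if jnp.ndim(inputs) == unbatched_ndim:
            return apply_fun(params, inputs[None])[0]
        return apply_fun(params, inputs)
    return init, apply

With a wrapper like this around stax.serial(...), net_init((28, 28, 1)) would behave the same way the MLP case does.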
On Fri, Feb 15, 2019 at 7:02 AM Peter Hawkins notifications@github.com wrote:
We’ve got a plan! Will update this issue as we make progress.