Inconsistent batching behavior between MLPs and convnets

The following code snippets show that stax MLPs can be defined w.r.t. unbatched examples (in_shape = (1,)), while convnets seem to require a leading batch dimension in the input shape (though its size can be -1). Is this intended behavior?

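# stax location at the time of this issue (early 2019). Newer JAX releases ship it
# as jax.example_libraries.stax, where init functions also take an RNG key first.
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, Flatten, LogSoftmax, MaxPool, Relu

# Works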
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)

# Initialize parameters, not committing to a batch shape
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(in_shape)

# Works
net_init, net_apply = stax.serial(
    Dense(40), Relu,
    Dense(40), Relu,
    Dense(1)
)
in_shape = (1,)
out_shape, net_params = net_init(in_shape)
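A rough, shape-only sketch of why the MLP accepts (1,): a stax-style Dense layer reads only the last axis of its input shape, so any leading axes (a batch axis, or none at all) just pass through. The helper dense_shapes is hypothetical; it mirrors the shape arithmetic rather than calling JAX.

```python
def dense_shapes(input_shape, out_dim):
    """Hypothetical helper mirroring the shape logic of a stax-style Dense layer.

    Only input_shape[-1] is read, so the leading axes (a batch axis, or none
    at all) are simply carried through to the output shape.
    """
    output_shape = tuple(input_shape[:-1]) + (out_dim,)
    weight_shape = (input_shape[-1], out_dim)
    bias_shape = (out_dim,)
    return output_shape, (weight_shape, bias_shape)


# Unbatched input, as in the MLP example above.
print(dense_shapes((1,), 40))     # ((40,), ((1, 40), (40,)))
# The same layer with a leading, unspecified batch axis.
print(dense_shapes((-1, 1), 40))  # ((-1, 40), ((1, 40), (40,)))
```

The parameter shapes are identical in both calls, which is why the MLP never has to commit to a batch dimension.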

# Doesn't Work
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)
in_shape = (28, 28, 1)
out_shape, net_params = net_init(in_shape)

The last one returns the following error:

IndexError                                Traceback (most recent call last)
<ipython-input-4-d145a7688535> in <module>()
      9 # Initialize parameters, not committing to a batch shape
     10 in_shape = (28, 28, 1)
---> 11 out_shape, net_params = net_init(in_shape)

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
    269     params = []
    270     for init_fun in init_funs:
--> 271       input_shape, param = init_fun(input_shape)
    272       params.append(param)
    273     return input_shape, params

google3/third_party/py/jax/experimental/stax.py in init_fun(input_shape)
    109     kernel_shape = [out_chan if c == 'O' else
    110                     input_shape[lhs_spec.index('C')] if c == 'I' else
--> 111                     next(filter_shape_iter) for c in rhs_spec]
    112     output_shape = lax.conv_general_shape_tuple(
    113         input_shape, kernel_shape, strides, padding, dimension_numbers)

IndexError: tuple index out of range
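The last traceback frame points at the cause. A shape-only sketch, assuming Conv uses the NHWC input / HWIO kernel dimension numbers that stax's Conv defaults to: the input channel count is looked up at lhs_spec.index('C'), which is 3 for 'NHWC', so a rank-3 shape such as (28, 28, 1) has no index 3. The function conv_kernel_shape is hypothetical; it copies the list comprehension from the traceback instead of calling JAX.

```python
def conv_kernel_shape(input_shape, out_chan, filter_shape,
                      lhs_spec="NHWC", rhs_spec="HWIO"):
    """Hypothetical copy of the kernel-shape logic shown in the traceback."""
    filter_shape_iter = iter(filter_shape)
    return tuple(
        out_chan if c == "O"
        else input_shape[lhs_spec.index("C")] if c == "I"  # index 3 for NHWC
        else next(filter_shape_iter)
        for c in rhs_spec
    )


# Batched (or -1 batch) shape: axis 3 holds the channel count.
print(conv_kernel_shape((-1, 28, 28, 1), 32, (3, 3)))  # (3, 3, 1, 32)

# Unbatched shape: there is no axis 3, reproducing the reported error.
try:
    conv_kernel_shape((28, 28, 1), 32, (3, 3))
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```

In other words, the convolution layer reads its channel axis by position in a fixed four-letter layout string, whereas Dense only ever reads the last axis.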

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 14 (14 by maintainers)

Top GitHub Comments

4 reactions
ericjang commented, Feb 15, 2019

Agreed, there definitely should be the flexibility to define batched nets (for models whose forward pass requires minibatches). But given the behavior of the MLP, a user can easily suspect that Conv would automatically prepend a singleton batch dimension under the hood if ndims == 3. It is confusing if some primitives assume batching while others do not.
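A minimal sketch of the behavior described here, under stated assumptions: a hypothetical wrapper add_singleton_batch prepends a batch axis of size 1 when an init function receives a shape one axis short, then strips that axis from the output shape. toy_conv_init is a stand-in, not the stax API: like an NHWC convolution with 'SAME' padding and stride 1, it reads the channel count from axis 3 and returns a kernel shape in place of real parameters. A matching apply function would also have to add and remove the axis on the actual arrays.

```python
def toy_conv_init(input_shape, out_chan=32, window=(3, 3)):
    """Stand-in for an NHWC conv init: needs a rank-4 shape (N, H, W, C).

    With 'SAME' padding and stride 1 the N, H and W axes are unchanged.
    Returns (output_shape, kernel_shape); the kernel shape stands in for params.
    """
    in_chan = input_shape[3]  # IndexError for an unbatched (H, W, C) shape
    kernel_shape = tuple(window) + (in_chan, out_chan)
    output_shape = tuple(input_shape[:3]) + (out_chan,)
    return output_shape, kernel_shape


def add_singleton_batch(init_fun, batched_ndim=4):
    """Hypothetical wrapper: accept shapes with or without the batch axis."""
    def wrapped(input_shape):
        if len(input_shape) == batched_ndim - 1:
            # Pretend there is a batch of one, then drop that axis again.
            output_shape, params = init_fun((1,) + tuple(input_shape))
            return output_shape[1:], params
        return init_fun(input_shape)
    return wrapped


init = add_singleton_batch(toy_conv_init)
print(init((28, 28, 1)))      # ((28, 28, 32), (3, 3, 1, 32))
print(init((-1, 28, 28, 1)))  # ((-1, 28, 28, 32), (3, 3, 1, 32))
```

Because the batch axis never affects parameter shapes here, both calls produce the same kernel shape; the trade-off is that the layer now silently guesses what a rank-3 shape means.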

On Fri, Feb 15, 2019 at 7:02 AM Peter Hawkins notifications@github.com wrote:

If we removed the batch dimension from stax, it’s not obvious to me how to define batch norm.
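The concern is that batch normalization is defined by statistics taken over the batch axis. A minimal NumPy sketch of the training-time computation (no running averages; the function name batch_norm_train is hypothetical, not a stax layer) shows why a per-example model has nothing to normalize against: with a batch of one, every feature equals its own mean, so the output collapses to beta.

```python
import numpy as np


def batch_norm_train(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm over axis 0 (the batch axis) of x.

    x has shape (batch, features); mean and variance are per feature,
    computed across the examples in the batch.
    """
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


batch = np.array([[1.0, 2.0],
                  [3.0, 6.0],
                  [5.0, 10.0]])
out = batch_norm_train(batch)
print(out.mean(axis=0))  # close to [0, 0]
print(out.std(axis=0))   # close to [1, 1]

# With a single example there is nothing to normalize against.
print(batch_norm_train(batch[:1]))  # [[0. 0.]]
```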


3 reactions
mattjj commented, Feb 15, 2019

We’ve got a plan! Will update this issue as we make progress.
