Update error message for GroupNorm default argument `group_size`
Problem you have encountered:
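Because the default value of `num_groups` is not `None` (it defaults to 32), the argument check in `flax/linen/normalization.py` fails whenever only `group_size` is passed. Here's a small snippet to reproduce it: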
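```python
import jax
from jax import numpy as jnp
import flax

rng = jax.random.PRNGKey(0)
# Only group_size is passed; num_groups keeps its non-None default.
gn = flax.linen.GroupNorm(group_size=1)
params = gn.init(rng, jnp.ones((2, 2, 2)))  # raises ValueError
```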
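To see why the repro fails, here is a minimal stdlib-only sketch of the argument check shown in the traceback below. The helper name `check_group_args` is hypothetical, and the default `num_groups=32` is an assumption mirroring the Flax `GroupNorm` default. This is not Flax's actual code.

```python
from typing import Optional


def check_group_args(num_groups: Optional[int] = 32,
                     group_size: Optional[int] = None) -> None:
    """Hypothetical mimic of the GroupNorm check: exactly one argument must be set.

    The default of 32 for `num_groups` is an assumed mirror of Flax's default.
    """
    # Equivalent to the traceback's condition: raise if both are None
    # or if both are not None.
    if (num_groups is None) == (group_size is None):
        raise ValueError('Either `num_groups` or `group_size` should be '
                         'specified, but not both of them.')


# Passing only group_size leaves num_groups at its default of 32,
# so both count as "set" and the check raises, just like the repro.
try:
    check_group_args(group_size=1)
except ValueError as e:
    print('raised:', e)

# Explicitly clearing num_groups makes group_size the only setting.
check_group_args(num_groups=None, group_size=1)
print('ok')
```

With the real layer, the corresponding workaround is to pass both arguments explicitly: `flax.linen.GroupNorm(num_groups=None, group_size=1)`.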
What you expected to happen:
Setting `group_size` alone should be allowed, and the validation logic is clearly meant to support that. It looks like a small typo to correct, though.
Logs, error messages, etc:
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-15-3fa6bd6dcb5c> in <module>
5 rng = jax.random.PRNGKey(0)
6 gn = flax.linen.GroupNorm(group_size=2)
----> 7 params = gn.init(rng, jnp.ones((2,2,2)))
[... skipping hidden 11 frame]
/usr/local/lib/python3.8/dist-packages/flax/linen/normalization.py in __call__(self, x)
320 if ((self.num_groups is None and self.group_size is None) or
321 (self.num_groups is not None and self.group_size is not None)):
--> 322 raise ValueError('Either `num_groups` or `group_size` should be '
323 'specified, but not both of them.')
324 num_groups = self.num_groups
ValueError: Either `num_groups` or `group_size` should be specified, but not both of them.
```
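The issue title asks for a clearer error message. One possible shape for that is sketched below in stdlib-only Python: validate the two arguments, say how to override the default, and derive the group count from `group_size`. The names `resolve_num_groups` and `DEFAULT_NUM_GROUPS`, the divisibility checks, and the message wording are illustrative assumptions. They are not Flax's implementation or its actual message.

```python
from typing import Optional

# Assumed default, mirroring Flax's GroupNorm `num_groups=32`.
DEFAULT_NUM_GROUPS = 32


def resolve_num_groups(num_features: int,
                       num_groups: Optional[int] = DEFAULT_NUM_GROUPS,
                       group_size: Optional[int] = None) -> int:
    """Hypothetical helper: validate the arguments and return the group count."""
    if num_groups is not None and group_size is not None:
        # The likely cause is the non-None default, so say how to override it.
        raise ValueError(
            'Either `num_groups` or `group_size` should be specified, but not '
            'both of them. If `group_size` is to be specified, pass '
            '`num_groups=None` to override the default `num_groups` value '
            f'of {DEFAULT_NUM_GROUPS}.')
    if num_groups is None and group_size is None:
        raise ValueError('Either `num_groups` or `group_size` must be specified.')
    if group_size is not None:
        # Every group must contain exactly `group_size` channels.
        if num_features % group_size != 0:
            raise ValueError(f'Number of features ({num_features}) is not '
                             f'divisible by group_size ({group_size}).')
        return num_features // group_size
    if num_features % num_groups != 0:
        raise ValueError(f'Number of features ({num_features}) is not '
                         f'divisible by num_groups ({num_groups}).')
    return num_groups


# The repro's input jnp.ones((2, 2, 2)) has 2 features on the last axis.
print(resolve_num_groups(2, num_groups=None, group_size=1))  # 2 groups of 1 channel
```

Deriving the group count as `num_features // group_size` only makes sense when the feature count divides evenly, which is why the sketch checks divisibility before returning.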
Steps to reproduce:
See the code snippet above.
Issue Analytics
- Created 2 years ago
- Comments: 9 (3 by maintainers)
Top GitHub Comments
@sudhakarsingh27 thanks! I think that error should be fixed at head now.
Thanks! For more help see our How to contribute guide