BatchNormalization on non-last axis throws ValueError with TensorFlow backend
from keras.layers import BatchNormalization, Input
x = Input(shape=(1, 2, 2))
BatchNormalization(axis=1)(x)
The code above fails with ValueError: Shape must be rank 1 but is rank 4 for 'batch_normalization_1/cond/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,1,2,2], [1,1,1,1], [1,1,1,1], [1,1,1,1], [1,1,1,1]. when using TensorFlow as the backend, meaning the BatchNormalization layer currently cannot be used with the channels_first format.
This seems to be caused by the changes from #10207. tf.nn.fused_batch_norm expects 1D tensors for its non-input parameters, but the inference branch of keras.layers.normalization.BatchNormalization.call broadcasts them to 4D before calling keras.backend.batch_normalization. That broadcasting is not required for TF's fused batch norm.
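The mismatch can be reproduced directly against TensorFlow, without Keras in the loop. Below is a minimal sketch, assuming TensorFlow 1.x graph mode as in the traceback: the 1-D per-channel parameters build fine, while 4-D broadcast versions (like the ones normalize_inference() constructs) trigger the same rank error at graph-construction time.

import tensorflow as tf

# 1 channel, NCHW layout, matching Input(shape=(1, 2, 2)) above.
x = tf.placeholder(tf.float32, shape=(None, 1, 2, 2))

# 1-D per-channel parameters: what FusedBatchNorm expects.
ones_1d, zeros_1d = tf.ones([1]), tf.zeros([1])
y, _, _ = tf.nn.fused_batch_norm(
    x, scale=ones_1d, offset=zeros_1d, mean=zeros_1d, variance=ones_1d,
    data_format='NCHW', is_training=False)  # graph builds fine

# 4-D broadcast parameters, like those built by normalize_inference():
ones_4d, zeros_4d = tf.ones([1, 1, 1, 1]), tf.zeros([1, 1, 1, 1])
y_bad, _, _ = tf.nn.fused_batch_norm(
    x, scale=ones_4d, offset=zeros_4d, mean=zeros_4d, variance=ones_4d,
    data_format='NCHW', is_training=False)
# ValueError: Shape must be rank 1 but is rank 4 ...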
Full traceback:
InvalidArgumentError Traceback (most recent call last)
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1566 try:
-> 1567 c_op = c_api.TF_FinishOperation(op_desc)
1568 except errors.InvalidArgumentError as e:
InvalidArgumentError: Shape must be rank 1 but is rank 4 for 'batch_normalization_1/cond/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,1,2,2], [1,1,1,1], [1,1,1,1], [1,1,1,1], [1,1,1,1].
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-1-d7f9f59b50e7> in <module>()
1 from keras.layers import BatchNormalization, Input
2 x = Input(shape=(1, 2, 2))
----> 3 BatchNormalization(axis=1)(x)
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
458 # Actually call the layer,
459 # collecting output(s), mask(s), and shape(s).
--> 460 output = self.call(inputs, **kwargs)
461 output_mask = self.compute_mask(inputs, previous_mask)
462
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\layers\normalization.py in call(self, inputs, training)
202 return K.in_train_phase(normed_training,
203 normalize_inference,
--> 204 training=training)
205
206 def get_config(self):
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\backend\tensorflow_backend.py in in_train_phase(x, alt, training)
3067
3068 # else: assume learning phase is a placeholder tensor.
-> 3069 x = switch(training, x, alt)
3070 if uses_learning_phase:
3071 x._uses_learning_phase = True
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\backend\tensorflow_backend.py in switch(condition, then_expression, else_expression)
3002 x = tf.cond(condition,
3003 then_expression_fn,
-> 3004 else_expression_fn)
3005 else:
3006 # tf.where needs its condition tensor
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
430 'in a future version' if date is None else ('after %s' % date),
431 instructions)
--> 432 return func(*args, **kwargs)
433 return tf_decorator.make_decorator(func, new_func, 'deprecated',
434 _add_deprecated_arg_notice_to_docstring(
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
2070 context_f = CondContext(pred, pivot_2, branch=0)
2071 context_f.Enter()
-> 2072 orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
2073 if orig_res_f is None:
2074 raise ValueError("false_fn must have a return value.")
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self, fn)
1911 """Add the subgraph defined by fn() to the graph."""
1912 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
-> 1913 original_result = fn()
1914 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
1915 if len(post_summaries) > len(pre_summaries):
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\layers\normalization.py in normalize_inference()
163 broadcast_gamma,
164 axis=self.axis,
--> 165 epsilon=self.epsilon)
166 else:
167 return K.batch_normalization(
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\backend\tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
1892 variance=var,
1893 data_format=tf_data_format,
-> 1894 is_training=False
1895 )
1896 return y
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\nn_impl.py in fused_batch_norm(x, scale, offset, mean, variance, epsilon, data_format, is_training, name)
902 data_format=data_format,
903 is_training=is_training,
--> 904 name=name)
905 return y, batch_mean, batch_var
906
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\gen_nn_ops.py in _fused_batch_norm(x, scale, offset, mean, variance, epsilon, data_format, is_training, name)
3772 "FusedBatchNorm", x=x, scale=scale, offset=offset, mean=mean,
3773 variance=variance, epsilon=epsilon, data_format=data_format,
-> 3774 is_training=is_training, name=name)
3775 _result = _op.outputs[:]
3776 _inputs_flat = _op.inputs
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
785 op = g.create_op(op_type_name, inputs, output_types, name=scope,
786 input_types=input_types, attrs=attr_protos,
--> 787 op_def=op_def)
788 return output_structure, op_def.is_stateful, op
789
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in create_op(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_shapes, compute_device)
3390 input_types=input_types,
3391 original_op=self._default_original_op,
-> 3392 op_def=op_def)
3393
3394 # Note: shapes are lazily computed with the C API enabled.
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
1732 op_def, inputs, node_def.attr)
1733 self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1734 control_input_ops)
1735 else:
1736 self._c_op = None
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
1568 except errors.InvalidArgumentError as e:
1569 # Convert to ValueError for backwards compatibility.
-> 1570 raise ValueError(str(e))
1571
1572 return c_op
ValueError: Shape must be rank 1 but is rank 4 for 'batch_normalization_1/cond/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,1,2,2], [1,1,1,1], [1,1,1,1], [1,1,1,1], [1,1,1,1].

Hi,
Just copying over the answer from Stack Overflow, which works:
Try Keras 2.1.6:
pip uninstall keras
pip install -I keras==2.1.6
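If pinning 2.1.6 isn't an option, a possible workaround (my own sketch, not from the thread, and assuming the default channels_last path with axis=-1 is unaffected) is to permute the tensor to channels_last around the normalization, so the fused NCHW code path is never taken:

from keras.layers import BatchNormalization, Input, Permute
from keras.models import Model

x = Input(shape=(1, 2, 2))           # channels_first: (C, H, W)
h = Permute((2, 3, 1))(x)            # -> (H, W, C), channels_last
h = BatchNormalization(axis=-1)(h)   # default axis; avoids the NCHW fused op
y = Permute((3, 1, 2))(h)            # -> back to (C, H, W)
model = Model(inputs=x, outputs=y)

This adds two transpose ops per normalization, so it's a stopgap rather than a fix.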
This bug still exists in recent releases (2.2.4). Since it doesn't exist in 2.2.2, it appears to be a regression in the 2.2.x series, but it isn't being addressed in bug-fix releases. Is it possible to get a fix into the next bug-fix release so this issue can be closed?