BatchNormalization on non-last axis throws ValueError with TensorFlow backend

See original GitHub issue
from keras.layers import BatchNormalization, Input

x = Input(shape=(1, 2, 2))     # channels_first: (channels, height, width)
BatchNormalization(axis=1)(x)  # normalize the channel axis -> ValueError

The code above fails with ValueError: Shape must be rank 1 but is rank 4 for 'batch_normalization_1/cond/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,1,2,2], [1,1,1,1], [1,1,1,1], [1,1,1,1], [1,1,1,1]. when using TensorFlow as the backend, meaning the BatchNormalization layer currently cannot be used with the channels_first format.
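For contrast, the same layer builds without error when the normalized axis is the last one (channels_last), which narrows the failure to the non-last-axis code path:

from keras.layers import BatchNormalization, Input

x = Input(shape=(2, 2, 1))      # channels_last: (height, width, channels)
BatchNormalization(axis=-1)(x)  # last-axis normalization builds fine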

This seems to be caused by the changes from #10207. tf.nn.fused_batch_norm expects 1D tensors for its scale, offset, mean, and variance parameters, but the inference branch of keras.layers.normalization.BatchNormalization.call passes 4D tensors to keras.backend.batch_normalization when it broadcasts them. That broadcasting is not required for TF’s fused batch norm.
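A minimal sketch of the mismatch, using the TF 1.x graph API shown in the traceback (the placeholder shape here mirrors the repro and is illustrative):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 1, 2, 2))  # NCHW, one channel

# 1D per-channel parameters (length == channel count) are what the fused
# op expects, so this builds fine:
y, _, _ = tf.nn.fused_batch_norm(
    x,
    scale=tf.ones([1]), offset=tf.zeros([1]),
    mean=tf.zeros([1]), variance=tf.ones([1]),
    data_format='NCHW', is_training=False)

# The rank-4 broadcast parameters Keras builds (shape [1, 1, 1, 1]) trigger
# the "Shape must be rank 1 but is rank 4" error from the traceback:
# tf.nn.fused_batch_norm(x, scale=tf.ones([1, 1, 1, 1]), ...)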

Full traceback:

InvalidArgumentError                      Traceback (most recent call last)
C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1566   try:
-> 1567     c_op = c_api.TF_FinishOperation(op_desc)
   1568   except errors.InvalidArgumentError as e:

InvalidArgumentError: Shape must be rank 1 but is rank 4 for 'batch_normalization_1/cond/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,1,2,2], [1,1,1,1], [1,1,1,1], [1,1,1,1], [1,1,1,1].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-1-d7f9f59b50e7> in <module>()
      1 from keras.layers import BatchNormalization, Input
      2 x = Input(shape=(1, 2, 2))
----> 3 BatchNormalization(axis=1)(x)

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
    458             # Actually call the layer,
    459             # collecting output(s), mask(s), and shape(s).
--> 460             output = self.call(inputs, **kwargs)
    461             output_mask = self.compute_mask(inputs, previous_mask)
    462 

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\layers\normalization.py in call(self, inputs, training)
    202         return K.in_train_phase(normed_training,
    203                                 normalize_inference,
--> 204                                 training=training)
    205 
    206     def get_config(self):

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\backend\tensorflow_backend.py in in_train_phase(x, alt, training)
   3067 
   3068     # else: assume learning phase is a placeholder tensor.
-> 3069     x = switch(training, x, alt)
   3070     if uses_learning_phase:
   3071         x._uses_learning_phase = True

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\backend\tensorflow_backend.py in switch(condition, then_expression, else_expression)
   3002         x = tf.cond(condition,
   3003                     then_expression_fn,
-> 3004                     else_expression_fn)
   3005     else:
   3006         # tf.where needs its condition tensor

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    430                 'in a future version' if date is None else ('after %s' % date),
    431                 instructions)
--> 432       return func(*args, **kwargs)
    433     return tf_decorator.make_decorator(func, new_func, 'deprecated',
    434                                        _add_deprecated_arg_notice_to_docstring(

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in cond(pred, true_fn, false_fn, strict, name, fn1, fn2)
   2070     context_f = CondContext(pred, pivot_2, branch=0)
   2071     context_f.Enter()
-> 2072     orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
   2073     if orig_res_f is None:
   2074       raise ValueError("false_fn must have a return value.")

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\control_flow_ops.py in BuildCondBranch(self, fn)
   1911     """Add the subgraph defined by fn() to the graph."""
   1912     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
-> 1913     original_result = fn()
   1914     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1915     if len(post_summaries) > len(pre_summaries):

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\layers\normalization.py in normalize_inference()
    163                     broadcast_gamma,
    164                     axis=self.axis,
--> 165                     epsilon=self.epsilon)
    166             else:
    167                 return K.batch_normalization(

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\keras\backend\tensorflow_backend.py in batch_normalization(x, mean, var, beta, gamma, axis, epsilon)
   1892                 variance=var,
   1893                 data_format=tf_data_format,
-> 1894                 is_training=False
   1895             )
   1896             return y

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\nn_impl.py in fused_batch_norm(x, scale, offset, mean, variance, epsilon, data_format, is_training, name)
    902       data_format=data_format,
    903       is_training=is_training,
--> 904       name=name)
    905   return y, batch_mean, batch_var
    906 

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\ops\gen_nn_ops.py in _fused_batch_norm(x, scale, offset, mean, variance, epsilon, data_format, is_training, name)
   3772         "FusedBatchNorm", x=x, scale=scale, offset=offset, mean=mean,
   3773         variance=variance, epsilon=epsilon, data_format=data_format,
-> 3774         is_training=is_training, name=name)
   3775     _result = _op.outputs[:]
   3776     _inputs_flat = _op.inputs

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    785         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    786                          input_types=input_types, attrs=attr_protos,
--> 787                          op_def=op_def)
    788       return output_structure, op_def.is_stateful, op
    789 

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in create_op(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_shapes, compute_device)
   3390           input_types=input_types,
   3391           original_op=self._default_original_op,
-> 3392           op_def=op_def)
   3393 
   3394       # Note: shapes are lazily computed with the C API enabled.

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1732           op_def, inputs, node_def.attr)
   1733       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1734                                 control_input_ops)
   1735     else:
   1736       self._c_op = None

C:\ProgramData\Anaconda3\envs\keras\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1568   except errors.InvalidArgumentError as e:
   1569     # Convert to ValueError for backwards compatibility.
-> 1570     raise ValueError(str(e))
   1571 
   1572   return c_op

ValueError: Shape must be rank 1 but is rank 4 for 'batch_normalization_1/cond/FusedBatchNorm' (op: 'FusedBatchNorm') with input shapes: [?,1,2,2], [1,1,1,1], [1,1,1,1], [1,1,1,1], [1,1,1,1].

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Reactions: 11
  • Comments: 12 (2 by maintainers)

Top GitHub Comments

19 reactions
rockcat commented, Jun 27, 2018

Hi

Just copying over the answer from Stack Overflow, which works:

Try keras 2.1.6:

pip uninstall keras
pip install -I keras==2.1.6

4 reactions
n-west commented, Oct 16, 2018

This bug still exists in recent releases (2.2.4). Since it doesn’t exist in 2.2.2, it looks like a regression in the 2.2.x series, but it isn’t being included in bug-fix releases. Is it possible to get it into the next bug-fix release and close the issue?
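Until a fix lands, one hypothetical workaround (not from this thread) is to move the channel axis to the last position, normalize there where the last-axis path is unaffected, and move it back:

from keras.layers import BatchNormalization, Input, Permute
from keras.models import Model

inp = Input(shape=(1, 2, 2))        # channels_first tensor
h = Permute((2, 3, 1))(inp)         # move channels to the last axis
h = BatchNormalization(axis=-1)(h)  # last-axis normalization builds fine
out = Permute((3, 1, 2))(h)         # restore the channels_first layout
model = Model(inp, out)

This trades two transpose ops per call for a working normalization; whether that cost is acceptable depends on the model.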

Read more comments on GitHub >

Top Results From Across the Web

tf.keras.layers.BatchNormalization | TensorFlow v2.11.0
Batch normalization applies a transformation that maintains the mean ... the axis that should be normalized (typically the features axis).
Read more >
Keras error: "BatchNormalization Shape must be rank 1 but is ...
I have a Keras functional model (Neural network with convolutional layers) which works fine with tensorflow. I can run it and I can...
Read more >
tf.keras.backend.batch_normalization | TensorFlow
Defined in tensorflow/python/keras/backend.py . Applies batch normalization on x given mean, var, beta and gamma. I.e. returns: output = (x - mean) ...
Read more >
layer_batch_normalization - TensorFlow for R - RStudio
Normalize the activations of the previous layer at each batch, i.e. applies a transformation that maintains the mean activation close to 0 and...
Read more >
Batch Normalization in practice: an example with Keras and ...
A step by step tutorial to add and customize batch normalization ... Keras and TensorFlow 2.0 only take in Numpy array as inputs, ......
Read more >
