BatchNormalization does not support integer inputs.
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): MacOS 10.15.7
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): v2.7.0-rc1-69-gc256c071bb2 2.7.0
- Python version: 3.8.12
- Bazel version (if compiling from source): N/A
- GPU model and memory: N/A
- Exact command to reproduce: see below
Describe the problem.
The following code fails because the BatchNormalization layer does not support inputs of type uint8 (see the full stack trace below). You can run this code in this gist.
import tensorflow as tf
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd")
model.fit(X_train, y_train, epochs=2)
Describe the current behavior.
The code above raises the following exception: TypeError: Exception encountered when calling layer "batch_normalization_1" (type BatchNormalization). Input 'y' of 'AddV2' Op has type float32 that does not match type uint8 of argument 'x'.
See the full stacktrace below.
Describe the expected behavior.
The BatchNormalization layer should automatically cast integer inputs to floats.
If you remove the BatchNormalization layer, everything works fine, because the Dense layer casts its inputs to floats automatically. I expect these layers to behave in the same way: either both of them should cast integers to floats when needed, or neither of them should. IMO the first option is preferable.
- Do you want to contribute a PR? (yes/no): yes
- Briefly describe your candidate solution (if contributing): in the BatchNormalization layer, cast integer inputs to float32.
Source code / logs.
Full stack trace:
Epoch 1/2
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-2-15a576831687> in <module>()
7 ])
8 model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd")
----> 9 model.fit(X_train, y_train, epochs=2)
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
1127 except Exception as e: # pylint:disable=broad-except
1128 if hasattr(e, "ag_error_metadata"):
-> 1129 raise e.ag_error_metadata.to_exception(e)
1130 else:
1131 raise
TypeError: in user code:
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 878, in train_function *
return step_function(self, iterator)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 867, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in run_step **
outputs = model.train_step(data)
File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 808, in train_step
y_pred = self(x, training=True)
File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
TypeError: Exception encountered when calling layer "batch_normalization" (type BatchNormalization).
Input 'y' of 'AddV2' Op has type float32 that does not match type uint8 of argument 'x'.
Call arguments received:
• inputs=tf.Tensor(shape=(32, 784), dtype=uint8)
• training=True
I have submitted a change that makes all normalization layers cast their input to their compute_dtype, which resolves this issue. We can do this in additional layers too if there’s demand.

The way casting currently works in Keras layers is that each layer has a “dtype policy”, which contains a “variable dtype” and a “compute dtype”. By default both are equal to float32, but they can have different values (e.g. in mixed precision you’d use a policy with a float32 variable dtype and a float16 compute dtype).
All layers will cast their inputs to their compute dtype, BUT this only happens for floating-point inputs (e.g. casting float64 to float32). In your case no casting happens because the input is of integer type.
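For concreteness, here is a small illustration (not part of the original comment) of how the dtype policy and its two dtypes surface on a layer, and of the float-only auto-casting described above:

```python
import tensorflow as tf

# Default policy: variable dtype and compute dtype are both float32.
dense = tf.keras.layers.Dense(4)
print(dense.dtype_policy.name, dense.variable_dtype, dense.compute_dtype)
# float32 float32 float32

# Mixed precision policy: float32 variables, float16 computation.
mixed = tf.keras.layers.Dense(4, dtype="mixed_float16")
print(mixed.dtype_policy.name, mixed.variable_dtype, mixed.compute_dtype)
# mixed_float16 float32 float16

# Floating-point inputs are auto-cast to the compute dtype...
print(dense(tf.zeros((2, 3), dtype=tf.float64)).dtype)  # float32

# ...but an integer input is passed through uncast, which is why
# BatchNormalization then fails when mixing it with float32 tensors.
```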
We probably have two options here:
1) Always cast inputs to the compute dtype, including integer inputs, at the level of the base layer.
2) Keep the current behavior and only add the cast in the individual layers that need it (as was just done for the normalization layers).
I’m trying to think if there are cases where 1) would be obviously incorrect. Maybe image preprocessing layers? But even then the rule “cast to compute dtype” is simple and consistent. @mattdangerw I remember we looked at this in the context of KPL, do you remember what our conclusion was?
I’d favor doing 1) (at the level of the base layer) unless we find a significant reason why this would be incorrect. However, backwards-compatibility constraints might prevent us from doing so for layers where uint8 inputs are currently accepted and produce uint8 outputs.
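To make the cast under discussion concrete, here is a minimal, purely illustrative sketch (the layer name is hypothetical and this is not the actual base Layer implementation) of what option 1) amounts to:

```python
import tensorflow as tf

class CastToComputeDtype(tf.keras.layers.Layer):
    """Hypothetical layer illustrating option 1): unconditionally cast
    inputs, integer or not, to the layer's compute dtype."""

    def call(self, inputs):
        # self.compute_dtype is float32 under the default dtype policy.
        return tf.cast(inputs, self.compute_dtype)

# uint8 pixels come out as float32, so a downstream BatchNormalization
# no longer sees integer inputs.
x = tf.zeros((32, 784), dtype=tf.uint8)
print(CastToComputeDtype()(x).dtype)  # float32
```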