tf.keras.callbacks.TensorBoard: write_images does not visualize Conv2D weights
I moved this issue from https://github.com/tensorflow/tensorflow/issues/28767
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- TensorFlow version: v1.12.2-0-gcf74798993 1.12.2
- Python version: 3.6.5
Describe the current behavior
When I look at the weights of the Conv2D filters in TensorBoard, only their biases are logged (see attached image). I looked up the corresponding source code and found the following snippet:
https://github.com/tensorflow/tensorflow/blob/6612da89516247503f03ef76e974b51a434fb52e/tensorflow/python/keras/callbacks.py#L951-L983
The problem seems to be that Conv2D weights have a 4-D shape [H_kernel, W_kernel, C_in, C_out], which the convolutional-layer case in the linked code does not cover.
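The symptom can be illustrated with a small, hypothetical re-creation of a rank-based dispatch. This is a sketch, not TensorFlow's code: it assumes the callback squeezes out size-1 dimensions and then branches on the remaining rank (1-D bias, 2-D dense kernel, 3-D kernel), silently skipping anything else. The function name `image_shape_for_weight` is made up for illustration, and channels-last data format is assumed.

```python
def image_shape_for_weight(weight_shape):
    """Hypothetical sketch of a rank-based dispatch for write_images.

    Returns the [N, H, W, 1] image-batch shape a weight would be logged as,
    or None if the weight would be skipped. Assumes channels-last layout.
    """
    # Mimic squeezing: drop every size-1 dimension first
    shape = tuple(d for d in weight_shape if d != 1)
    if len(shape) == 1:
        # Bias vector -> a single 1 x n strip
        return (1, shape[0], 1, 1)
    if len(shape) == 2:
        # Dense kernel -> a single image, transposed so it is wider than tall
        rows, cols = shape
        if rows > cols:
            rows, cols = cols, rows
        return (1, rows, cols, 1)
    if len(shape) == 3:
        # Only reachable for a conv kernel after a size-1 axis was squeezed away;
        # move the channel axis to the front and log one image per channel
        h, w, c = shape
        return (c, h, w, 1)
    # Rank >= 4, e.g. [H_kernel, W_kernel, C_in, C_out]: skipped, nothing is logged
    return None
```

Under these assumptions, both Conv2D kernels in the reproducing model below stay rank 4 after squeezing and are skipped, while their biases are drawn, which matches the reported symptom. A conv kernel only reaches the 3-D branch when a size-1 dimension is squeezed away, for example a first layer with C_in = 1.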
Describe the expected behavior
I would expect the convolutional weights to be visualized. I know this would produce a large number of images (C_in * C_out per layer), but I think the current behavior is confusing.
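One way to visualize a 4-D kernel is to turn each of its C_in * C_out 2-D slices into a separate grayscale image. The sketch below does this with NumPy only; `kernel_to_image_batch` is a hypothetical helper, not part of TensorFlow, and the per-slice min-max scaling is one design choice among several (a shared scale across slices would preserve relative magnitudes instead).

```python
import numpy as np


def kernel_to_image_batch(kernel):
    """Hypothetical helper: turn a Conv2D kernel [H, W, C_in, C_out] into an
    image batch [C_in * C_out, H, W, 1] with each image scaled to [0, 1]."""
    kernel = np.asarray(kernel, dtype=np.float32)
    if kernel.ndim != 4:
        raise ValueError(f"expected a 4-D kernel, got shape {kernel.shape}")
    h, w, c_in, c_out = kernel.shape
    # Move both channel axes to the front, then merge them into one batch axis;
    # image index i = c * c_out + o holds kernel[:, :, c, o]
    batch = np.transpose(kernel, (2, 3, 0, 1)).reshape(c_in * c_out, h, w, 1)
    # Min-max scale each image independently; constant slices map to all zeros
    lo = batch.min(axis=(1, 2, 3), keepdims=True)
    hi = batch.max(axis=(1, 2, 3), keepdims=True)
    return (batch - lo) / np.maximum(hi - lo, 1e-12)
```

For the second Conv2D layer in the reproducer (C_in = 16, C_out = 32) this yields 512 small 3x3 images. Such a batch could then be handed to an image summary op such as `tf.summary.image`, whose `max_outputs` argument caps how many images are written; that step is not shown here.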
Code to reproduce the issue
import tensorflow as tf
# Load CIFAR-10 and scale pixel values to [0, 1]
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
print(type(x_train))
# Two Conv2D layers whose kernels are 4-D: [H_kernel, W_kernel, C_in, C_out]
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(filters=16, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.summary()
# write_images=True should log the layer weights as images in TensorBoard
tensorboard = tf.keras.callbacks.TensorBoard(log_dir="../../../logs", histogram_freq=1,
                                             write_images=True, write_grads=True)
csvlogger = tf.keras.callbacks.CSVLogger('train.log')
# lr= is the learning-rate argument name in TF 1.12
model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, callbacks=[tensorboard, csvlogger],
          validation_data=(x_test, y_test))
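To put numbers on the expected behavior, the weight shapes of the model above can be derived by hand from its layer definitions. This is a plain-Python sketch that does not call TensorFlow; the layer names assume default tf.keras naming (`conv2d`, `conv2d_1`, `dense`, `dense_1`) and are illustrative only, since real names can differ if other layers were created earlier in the same session.

```python
def reproducer_weight_shapes(input_shape=(32, 32, 3)):
    """Weight shapes of the reproducer model, derived by hand.

    Layer names are assumed tf.keras defaults and may differ in practice.
    """
    h, w, c = input_shape
    return {
        # Conv2D kernels are [H_kernel, W_kernel, C_in, C_out]
        "conv2d/kernel": (3, 3, c, 16),
        "conv2d/bias": (16,),
        "conv2d_1/kernel": (3, 3, 16, 32),
        "conv2d_1/bias": (32,),
        # padding='same' with the default stride 1 keeps H x W, so Flatten
        # produces h * w * 32 features
        "dense/kernel": (h * w * 32, 128),
        "dense/bias": (128,),
        "dense_1/kernel": (128, 10),
        "dense_1/bias": (10,),
    }


shapes = reproducer_weight_shapes()
# Only the 4-D weights are the Conv2D kernels
rank4 = {name: s for name, s in shapes.items() if len(s) == 4}
# One H_kernel x W_kernel image per (C_in, C_out) pair
images = {name: s[2] * s[3] for name, s in rank4.items()}
for name, n in images.items():
    print(f"{name}: {shapes[name]} -> {n} images")
print("total:", sum(images.values()))
```

By this count the two Conv2D kernels would need 48 and 512 images respectively, 560 in total per logging step, which is why some form of limiting or tiling would likely be needed if the callback visualized them.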
Other info / logs
Top GitHub Comments
cc @caisq FYI re tensor visualization
Hi @menon92,
No, I did not dig deeper into it. However, the issue also seems to be present in TF 2.0. If I find time to look into it more closely, I will share my findings.