`RandomAugmentationPipeline` serialization test fails for `layers!=[]`
In `layers/serialization_test.py`, the `RandomAugmentationPipeline` serialization test currently reads as follows:
https://github.com/keras-team/keras-cv/blob/92b4e16fa5686e861b2f2cd7d04a6c5d7d3c0921/keras_cv/layers/serialization_test.py#L98-L102
However, if we change it as follows, the test fails:
```python
(
    "RandomAugmentationPipeline",
    preprocessing.RandomAugmentationPipeline,
    {
        "layers": [preprocessing.RandomSaturation(0.5)],
        "augmentations_per_image": 1,
        "rate": 1.0,
    },
),
```
If I'm not mistaken, this implies that the `layers` argument, being a list, is not being tested correctly. We are facing the same problem as before: the test compares memory locations rather than logical values. When we manually print the internal layer attributes, they turn out to be the same.
```
# Original config:
{'name': 'random_augmentation_pipeline', 'trainable': True, 'dtype': 'float32', 'augmentations_per_image': 1, 'rate': 1.0, 'layers': ListWrapper([<keras_cv.layers.preprocessing.random_saturation.RandomSaturation object at 0x0000025C55B18488>]), 'seed': None}
# Reconstructed config:
{'name': 'random_augmentation_pipeline', 'trainable': True, 'dtype': 'float32', 'augmentations_per_image': 1, 'rate': 1.0, 'layers': ListWrapper([<keras_cv.layers.preprocessing.random_saturation.RandomSaturation object at 0x0000025C58530788>]), 'seed': None}
```
```python
# keras_cv\layers\serialization_test.py: 233
# Checking correctness of serialized and deserialized objects
def test_layer_serialization(self, layer_cls, init_args):
    layer = layer_cls(**init_args)
    if "seed" in init_args:
        self.assertIn("seed", layer.get_config())

    model = tf.keras.models.Sequential(layer)
    model_config = model.get_config()
    reconstructed_model = tf.keras.Sequential().from_config(model_config)
    reconstructed_layer = reconstructed_model.layers[0]

    if layer_cls == preprocessing.RandomAugmentationPipeline:
        print(layer.get_config()["layers"][0].factor.upper)
        print(reconstructed_layer.get_config()["layers"][0].factor.upper)

    self.assertTrue(
        config_equals(layer.get_config(), reconstructed_layer.get_config())
    )

# Output:
# 0.5
# 0.5
```
I think we need a more robust test for serialization and deserialization; otherwise, these problems will keep arising as the layers (and their arguments) grow more complex. Please share your thoughts.
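One possible direction, sketched here under stated assumptions: compare configs recursively, treating any object with a `get_config()` method as equal to another object of the same type whose config is (recursively) equal. `deep_config_equals` is a hypothetical helper, not an existing keras_cv function, and `FakeSaturation` is again a hypothetical stand-in layer. Values without `get_config()` (such as the `factor` object printed above, if it lacks one) still fall back to `==`.

```python
from collections.abc import Mapping, Sequence
from typing import Any


def deep_config_equals(a: Any, b: Any) -> bool:
    """Hypothetical helper: compare configs by value, recursing into nested layers."""
    # Serializable objects: same type and recursively equal configs.
    if hasattr(a, "get_config") and hasattr(b, "get_config"):
        return type(a) is type(b) and deep_config_equals(a.get_config(), b.get_config())
    # Dict-like configs: same keys, recursively equal values.
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        return a.keys() == b.keys() and all(deep_config_equals(a[k], b[k]) for k in a)
    # List-like values (e.g. a list of layers), excluding strings and bytes.
    text = (str, bytes)
    if (
        isinstance(a, Sequence) and isinstance(b, Sequence)
        and not isinstance(a, text) and not isinstance(b, text)
    ):
        return len(a) == len(b) and all(deep_config_equals(x, y) for x, y in zip(a, b))
    # Plain values fall back to ordinary equality.
    return a == b


class FakeSaturation:
    """Hypothetical stand-in for a serializable layer (not the keras_cv class)."""

    def __init__(self, factor):
        self.factor = factor

    def get_config(self):
        return {"factor": self.factor}


original = {"rate": 1.0, "layers": [FakeSaturation(0.5)]}
same = {"rate": 1.0, "layers": [FakeSaturation(0.5)]}
different = {"rate": 1.0, "layers": [FakeSaturation(0.7)]}

print(deep_config_equals(original, same))       # True
print(deep_config_equals(original, different))  # False
```

In the test above, this would mean replacing the final assertion with something like `self.assertTrue(deep_config_equals(layer.get_config(), reconstructed_layer.get_config()))`. Comparing by type plus config keeps the check strict (a different layer class with a matching config still fails) while ignoring object identity.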
Top GitHub Comments
Thank you @AdityaKane2001 !
Sounds good to me Aditya!