Support for nested `tf.keras.layers.Layer`
System information
- OS: Ubuntu 16.04
- TensorFlow version (you are using): tried both `2.3.0` and `tf-nightly` (`==2.5.0-dev20210223`)
- Are you willing to contribute it (Yes/No): I have tried, but to no avail - maybe with some help?
Motivation
Distantly related to #155, but instead of nested models this is a request for supporting nested layers. The steps needed to develop the feature might be quite similar, but potentially with fewer constraints, since the requirement of `tf.keras.Model` is dropped.
As documented in the official tensorflow guide, one of the recommended ways of building a TF model is via recursive composition of layers. Many architectures naturally decompose into such groups, and it would be beneficial if tfmot could natively support this kind of design pattern.
If it helps, the specific application that I was considering is applying tfmot to an optical flow network in the style of PWCNet. This would entail multiple levels of hierarchical blocks and reuse of such blocks in an hourglass architecture. Unfortunately, this undertaking has proved to be quite an awkward fit with the current design of tfmot. For reference, the quantization part of the project can be seen here.
Describe the feature
As a first milestone, perhaps layers that are pure compositions of `tf.keras.layers.Layer` can be supported, e.g.:
```python
class SampleNestedLayer(tf.keras.layers.Layer):
    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        self.conv = tf.keras.layers.Conv2D(16, 3, padding='valid')
        self.norm = tf.keras.layers.BatchNormalization()

    def call(self, x):
        return self.norm(self.conv(x))
```
I’m aware that the above layer, if not written as a subclassed layer of layers, can be quantized - please take it as an example of a layer that, despite its simplicity and a straightforward implementation that complies with the general guidelines of tf2/keras development, is nonetheless not supported as-is without rewriting the layer.
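For concreteness, here is a minimal sketch of the failure (reusing `SampleNestedLayer` from above; the input shape is arbitrary and the error description is paraphrased, not quoted):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

inputs = tf.keras.Input(shape=(32, 32, 3))
outputs = SampleNestedLayer()(inputs)
model = tf.keras.Model(inputs, outputs)

# Raises an error roughly stating that SampleNestedLayer is not supported
# and that a QuantizeConfig must be supplied via quantize_annotate_layer.
quantized = tfmot.quantization.keras.quantize_model(model)
```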
Describe how the feature helps achieve the use case
The above example, if supported as-is, would serve as a first pass toward supporting generally nested `Layer`s; even just pure compositions of existing keras layers would already bring significant expressive freedom.
Describe how existing APIs don’t satisfy your use case (optional if obvious)
I have tried a couple of approaches, but I think the difficulty in supporting nested layers comes in multiple pieces.
- Currently, layer annotation is enumerated via `model.layers`, which does not provide introspection into nested layers (see the sketch after this list);
- Nested models are not currently supported, so they are not a viable workaround;
- The quantization APIs via `QuantizeConfig` and `ModelTransformer` are not compatible; `ModelTransformer` does not support reused layers (which are common in e.g. optical flow networks with shared feature extractors).
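To illustrate the first point, continuing the snippet above:

```python
# Annotation enumerates only the top-level layers; the nested Conv2D and
# BatchNormalization never appear, so they can never be annotated directly.
print([l.__class__.__name__ for l in model.layers])
# -> ['InputLayer', 'SampleNestedLayer']

# The sub-layers are only reachable through private APIs, e.g.:
nested = model.layers[-1]._flatten_layers(recursive=False, include_self=False)
print([l.__class__.__name__ for l in nested])
# -> ['Conv2D', 'BatchNormalization']
```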
I have attempted something like the following:
- Subclassing `QuantizeConfig`:
```python
# Imports added here for completeness.
from collections import OrderedDict
import logging

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import (
    default_8bit_quantize_registry)

QuantizeConfig = tfmot.quantization.keras.QuantizeConfig


class RecursiveDelegateConfig(QuantizeConfig):
    """`QuantizeConfig` class that recursively supports sub-layers
    that are supported by the provided registry. Does not work with
    layers that require ModelTransformer.
    """

    def __init__(self, registry=None):
        if registry is None:
            registry = default_8bit_quantize_registry.Default8BitQuantizeRegistry()
        self.registry = registry
        self.wmap = OrderedDict()  # number of (weight, quantizer) pairs per sub-layer
        self.amap = OrderedDict()  # number of (activation, quantizer) pairs per sub-layer

    @staticmethod
    def get_sub_layers(layer: tf.keras.layers.Layer):
        layers = layer._flatten_layers(recursive=False, include_self=False)
        return sorted(((l.name, l) for l in layers), key=lambda pair: pair[0])

    def get_weights_and_quantizers(self, layer):
        # First, try if supported by the default registry.
        if self.registry.supports(layer):
            config = self.registry.get_quantize_config(layer)
            return config.get_weights_and_quantizers(layer)
        # Otherwise, assume this is a pure composition of keras layers
        # and process recursively. The requirement here is that
        # all leaf-node layers must be supported. Also, `layer` cannot
        # have weights of its own for now - just to preserve sanity.
        out = []
        for name, sub_layer in self.get_sub_layers(layer):
            # NOTE(ycho): Might want to dedup_weights.
            wnq = self.get_weights_and_quantizers(sub_layer)
            self.wmap[name] = len(wnq)
            out.extend(wnq)
        if not out:
            logging.warning(
                'empty output : perhaps there was an error? {}'.format(layer))
        return out

    def get_activations_and_quantizers(self, layer):
        if self.registry.supports(layer):
            config = self.registry.get_quantize_config(layer)
            return config.get_activations_and_quantizers(layer)
        out = []
        for name, sub_layer in self.get_sub_layers(layer):
            anq = self.get_activations_and_quantizers(sub_layer)
            self.amap[name] = len(anq)
            out.extend(anq)
        return out

    def set_quantize_weights(self, layer, quantize_weights):
        if self.registry.supports(layer):
            config = self.registry.get_quantize_config(layer)
            return config.set_quantize_weights(layer, quantize_weights)
        # Split the flat weight list back up using the counts recorded
        # in get_weights_and_quantizers().
        for name, sub_layer in self.get_sub_layers(layer):
            n = self.wmap[name]
            self.set_quantize_weights(sub_layer, quantize_weights[:n])
            quantize_weights = quantize_weights[n:]

    def set_quantize_activations(self, layer, quantize_activations):
        if self.registry.supports(layer):
            config = self.registry.get_quantize_config(layer)
            return config.set_quantize_activations(layer, quantize_activations)
        for name, sub_layer in self.get_sub_layers(layer):
            n = self.amap[name]
            self.set_quantize_activations(sub_layer, quantize_activations[:n])
            quantize_activations = quantize_activations[n:]

    def get_output_quantizers(self, layer):
        if self.registry.supports(layer):
            config = self.registry.get_quantize_config(layer)
            return config.get_output_quantizers(layer)
        out = []
        for _, sub_layer in self.get_sub_layers(layer):
            out.extend(self.get_output_quantizers(sub_layer))
        return out

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def get_config(self):
        return {
            'registry': self.registry,  # FIXME(ycho): probably doesn't work
        }

    def __eq__(self, other):
        return (isinstance(other, RecursiveDelegateConfig)
                and self.get_config() == other.get_config())

    def __ne__(self, other):
        return not self.__eq__(other)
```
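For reference, a sketch of how such a config would be attached, assuming the standard annotate/apply flow and reusing `SampleNestedLayer` from the example above:

```python
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_apply = tfmot.quantization.keras.quantize_apply

inputs = tf.keras.Input(shape=(32, 32, 3))
outputs = quantize_annotate_layer(
    SampleNestedLayer(), quantize_config=RecursiveDelegateConfig())(inputs)
model = quantize_annotate_model(tf.keras.Model(inputs, outputs))

# Custom objects are needed for the clone performed inside quantize_apply.
with tfmot.quantization.keras.quantize_scope(
        {'RecursiveDelegateConfig': RecursiveDelegateConfig,
         'SampleNestedLayer': SampleNestedLayer}):
    quantized = quantize_apply(model)
```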
This attempts to address the problem by extending the `QuantizeConfig` API. Unfortunately, layers such as `BatchNormalization` are only quantizable via `ModelTransformer` pattern matching (e.g. Conv-BN folding), and so cannot be supported with this approach.
- Then I tried to extend the `ModelTransformer` approach to flatten nested layers, which ended up not being possible due to the absence of support for layers with multiple connections (e.g. reused feature extractors). Copying the relevant lines below from `model_transformer.py#L138`:
```python
if len(inbound_nodes) > 1:
  # `layer` is re-used for more than 1 connection from previous layers. If
  # a pattern matches one set of inputs and is replaced, it will break the
  # other connection.
  #
  # Note that theoretically it's possible to have multiple connections have
  # exactly the same pattern, and in that case the transform might be
  # applied. But that's a very complicated edge case not worth handling.
  return False
```
Moreover, it can be quite unintuitive to replicate the exact underlying operations of the subclassed nested layer with a flat analogue.
- Subclassing a model and recursively applying the transformer does not work due to upstream (keras) limitations on cloning: `tf.keras.models.clone_model` does not support subclassed models.
- In the end, the only approach that worked was to manually flatten all the sub-layers, which unfortunately ended up with an undecipherable graph and a hideous coding pattern, as well as a large number of `TensorFlowOpLayer`s (a sketch of what this looks like follows below).
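To make "manually flatten" concrete, this is roughly the pattern it forces - a hypothetical sketch of the `SampleNestedLayer` case, not the actual PWCNet code:

```python
# Instead of `y = SampleNestedLayer()(x)`, every sub-layer has to be
# instantiated and wired at the top level of the functional model:
inputs = tf.keras.Input(shape=(32, 32, 3))
conv = tf.keras.layers.Conv2D(16, 3, padding='valid')
norm = tf.keras.layers.BatchNormalization()
outputs = norm(conv(inputs))
model = tf.keras.Model(inputs, outputs)

# Every block boundary disappears, reuse has to be tracked by hand, and
# any raw tf.* op inside a call() becomes a TensorFlowOpLayer in the graph.
```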
Hopefully that grounds the picture to some extent. While I think tfmot is already very well architected, being able to support this feature would make it a lot more powerful than it already is. My four or so attempts have all ended up being futile, so I was wondering whether there is a roadmap in the tfmot dev team for supporting nested layers. I haven’t seen explicit mentions of it in the issues (though some were similar), so I figured it would be worth bringing up.
Thank you!
Yoonyoung (Jamie) Cho
Seems reasonable – we are planning on adding this soon.
Just noting that subclassed layers don’t work with `tf.function`, nor do subclassed models. This is a very weird bug: if the function called directly in the training loop is a `tf.function`, the GPU will not be utilized (although the tensors are placed on the GPU). If we separate the `tf.GradientTape`s into different functions, each with its own `tf.function`, it works.
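A hypothetical sketch of the working variant as described above, assuming two sub-networks updated in the same loop (all names here are illustrative):

```python
import tensorflow as tf

net_a = tf.keras.Sequential([tf.keras.layers.Dense(1)])
net_b = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt_a = tf.keras.optimizers.Adam()
opt_b = tf.keras.optimizers.Adam()

def _mse(net, x, y):
    return tf.reduce_mean(tf.square(net(x, training=True) - y))

# Working variant: each GradientTape lives in its own tf.function,
# instead of one outer tf.function wrapping both tapes.
@tf.function
def step_a(x, y):
    with tf.GradientTape() as tape:
        loss = _mse(net_a, x, y)
    opt_a.apply_gradients(
        zip(tape.gradient(loss, net_a.trainable_variables),
            net_a.trainable_variables))

@tf.function
def step_b(x, y):
    with tf.GradientTape() as tape:
        loss = _mse(net_b, x, y)
    opt_b.apply_gradients(
        zip(tape.gradient(loss, net_b.trainable_variables),
            net_b.trainable_variables))

for _ in range(10):
    x, y = tf.random.normal((8, 4)), tf.random.normal((8, 1))
    step_a(x, y)
    step_b(x, y)
```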