RFC: Quantizers as subclasses of keras.layers.Layer
This Request for Comments (RFC) outlines a proposed change to the API and internal implementation of larq.quantizers. It would be good to get feedback on this, so please ask questions if I need to elaborate more in some areas.
Objective
We currently have three possible ways of accessing larq.quantizers:
- As a string: ste_sign
- As a class with optional configurable arguments: SteSign(clip_value=1.0)
- As a function: larq.quantizers.ste_sign
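For illustration, the three options look roughly like this when configuring a layer (a minimal sketch; larq.layers.QuantDense is just used as an example layer here):

import larq as lq

# 1. As a string
layer = lq.layers.QuantDense(32, kernel_quantizer="ste_sign")

# 2. As a class with optional configurable arguments
layer = lq.layers.QuantDense(32, kernel_quantizer=lq.quantizers.SteSign(clip_value=1.0))

# 3. As a function
layer = lq.layers.QuantDense(32, kernel_quantizer=lq.quantizers.ste_sign)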
This raises a few questions, since having several ways to achieve the same thing can be confusing, which we previously discussed in https://github.com/larq/larq/issues/246#issuecomment-536531445.
The way we currently handle metrics is also suboptimal: they are added to larq.layers directly, which makes the implementation messy and doesn't allow for metrics that access parts of the custom gradients. This makes it hard to implement ideas like #245 in a general way.
Design Proposal
In this RFC I’d like to propose three changes to the current API and implementation which should solve the above problems.
Proposal 1: Unify the larq.quantizers API (#387)
The functional way (option 3 above) of calling larq.quantizers doesn't add much value to larq and is only a shorthand for the class-based API, which is much more flexible. It adds overhead to the larq source code and would make Proposals 2 and 3 unnecessarily hard to implement. It might also be confusing for people looking at the API docs to see two almost identical ways to use larq.quantizers. I propose to remove this option.
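As a rough migration sketch, code that relies on the functional shorthand would switch to the class-based (or string) form; SteSign with its defaults behaves like ste_sign:

# Current functional shorthand (to be removed):
kernel_quantizer=lq.quantizers.ste_sign

# Equivalent usage that remains supported:
kernel_quantizer=lq.quantizers.SteSign()  # defaults to clip_value=1.0
kernel_quantizer="ste_sign"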
Proposal 2: larq.quantizers should subclass keras.layers.Layer (#387)
It is currently possible to use custom keras.layers.Layers as larq.quantizers (see #322). This can be useful when implementing quantizers with trainable variables, or with other variables that keep a running mean of a scaling factor, as is common in 8-bit quantizers. This change is only an implementation detail needed for Proposal 3 and doesn't change the way people currently use larq.quantizers.
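To give a purely illustrative picture of why this is useful, a custom quantizer layer can keep non-trainable state such as a running scale; the MovingAverageQuantizer name and its update rule below are made up for this sketch and are not part of the proposal:

import tensorflow as tf
import larq as lq


class MovingAverageQuantizer(tf.keras.layers.Layer):
    def __init__(self, momentum=0.99, **kwargs):
        super().__init__(**kwargs)
        self.momentum = momentum

    def build(self, input_shape):
        # Non-trainable state updated during training, similar to what 8-bit quantizers do.
        self.scale = self.add_weight(
            name="scale", shape=(), initializer="ones", trainable=False
        )
        super().build(input_shape)

    def call(self, inputs):
        # Track a running mean of the per-batch scale and use it to rescale the sign.
        batch_scale = tf.reduce_max(tf.abs(inputs))
        self.scale.assign(
            self.momentum * self.scale + (1.0 - self.momentum) * batch_scale
        )
        return lq.math.sign(inputs) * self.scale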
Proposal 3: Add metrics to larq.quantizers instead of larq.layers
The way we currently add metrics like flip_ratio to larq.layers is messy and doesn't allow for metrics that access custom gradients, which blocks #245.
After implementing Proposal 2, we can add relevant metrics directly to larq.quantizers, which gives better control over which metrics are added to which quantizer and allows direct access to custom gradients.
Implementation
The implementation could look something like this (not including configuration of metrics or special handling for eval/train):
import tensorflow as tf
import larq as lq


class GradientFlow(tf.keras.metrics.Metric):
    def __init__(self, name="gradient_flow", dtype=None):
        super().__init__(name=name, dtype=dtype)
        with tf.init_scope():
            self.total_value = self.add_weight(
                "total_value", initializer=tf.keras.initializers.zeros
            )
            self.num_batches = self.add_weight(
                "num_batches", initializer=tf.keras.initializers.zeros
            )

    def update_state(self, values):
        # Fraction of gradient entries that are non-zero, i.e. not clipped away.
        values = tf.cast(values, self.dtype)
        non_zero = tf.math.count_nonzero(values, dtype=self.dtype)
        num_activations = tf.cast(tf.size(values), self.dtype)
        update_total_op = self.total_value.assign_add(non_zero / num_activations)
        with tf.control_dependencies([update_total_op]):
            return self.num_batches.assign_add(1)

    def result(self):
        return tf.math.divide_no_nan(self.total_value, self.num_batches)


class SteSign(tf.keras.layers.Layer):
    def __init__(self, clip_value=1.0, **kwargs):
        self.precision = 1
        self.clip_value = clip_value
        super().__init__(**kwargs)

    def build(self, input_shape):
        # The metrics live on the quantizer itself instead of on the wrapping layer.
        self.flip_ratio = lq.metrics.FlipRatio(
            values_shape=input_shape, name=f"flip_ratio/{self.name}"
        )
        self.gradient_flow = GradientFlow(name=f"gradient_flow/{self.name}")
        super().build(input_shape)

    @tf.custom_gradient
    def call(self, inputs):
        def grad(dy):
            if self.clip_value is None:
                return dy
            zeros = tf.zeros_like(dy)
            mask = tf.math.less_equal(tf.math.abs(inputs), self.clip_value)
            clipped_gradients = tf.where(mask, dy, zeros)
            # The metric has direct access to the custom gradient.
            self.add_metric(self.gradient_flow(clipped_gradients))
            return clipped_gradients

        outputs = lq.math.sign(inputs)
        self.add_metric(self.flip_ratio(outputs))
        return outputs, grad

    @property
    def non_trainable_weights(self):
        # Don't expose the metric variables as weights of the quantizer.
        return []

    def get_config(self):
        return {**super().get_config(), "clip_value": self.clip_value}
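To show how this would surface for users, here is a rough usage sketch; the model architecture and the exact metric names that appear in the training logs are assumptions, not part of the proposal:

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        lq.layers.QuantDense(10, kernel_quantizer=SteSign(clip_value=1.0)),
    ]
)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# During fit(), the quantizer's own metrics (e.g. something like "flip_ratio/ste_sign")
# would show up next to the loss in the Keras progress bar and in the History object,
# without larq.layers having to register them.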
Top GitHub Comments
If there are no more comments from @larq/core I'll set the status of this proposal to accepted and we can start implementing it in larq.

Not quite. You can still use it like kernel_quantizer=SteSign(clip_value=1.0). Check out the unit test added in #322.