Questions about customizing training procedures
Describe the bug
I am trying to use a custom training procedure (inheriting from the keras.Model class and redefining both train_step and test_step) to train a quantized network defined with Larq, and I am seeing strange behavior. When the kernels are quantized, the training loss is very large. When only the activations are quantized, training is normal. However, when both kernels and activations are quantized at the same time, the training loss appears to be within the normal range, but the loss does not decrease and the accuracy does not increase. Everything works perfectly when I train directly with model.fit() instead.
To Reproduce
import tensorflow as tf
from tensorflow import keras

class CNN_model(keras.Model):
    def __init__(self, conv_model, compiled_acc):
        super(CNN_model, self).__init__()
        self.conv_model = conv_model
        self.compiled_acc = compiled_acc

    def train_step(self, data):
        x, y = data
        softmax = keras.layers.Softmax()
        with tf.GradientTape() as tape:
            logits = self.conv_model(x)
            loss_value = self.compiled_loss(y, softmax(logits))
        acc = self.compiled_acc(y, softmax(logits))  # accuracy
        grads = tape.gradient(loss_value, self.conv_model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.conv_model.trainable_variables))
        self.compiled_metrics.update_state(y, softmax(logits))
        return {'train_loss': loss_value, "train_acc": acc}

    def test_step(self, data):
        x, y = data
        softmax = keras.layers.Softmax()
        logits = self.conv_model(x)
        loss_value = self.compiled_loss(y, softmax(logits))
        val_acc = self.compiled_acc(y, softmax(logits))
        return {'loss': loss_value, "val_acc": val_acc}

    def call(self, inputs):
        return self.conv_model(inputs)
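The report does not show how the class is driven, so for context here is a minimal sketch of one way to exercise it. The binarized architecture, optimizer, and MNIST data below are illustrative assumptions, not part of the original report; they just make the snippet self-contained and runnable.

import tensorflow as tf
from tensorflow import keras
import larq as lq

# Illustrative binarized conv model (assumed, not from the report).
# First layer quantizes only the kernel; later layers quantize inputs too.
conv_model = keras.Sequential([
    lq.layers.QuantConv2D(
        32, (3, 3),
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
        use_bias=False,
        input_shape=(28, 28, 1)),
    keras.layers.BatchNormalization(scale=False),
    lq.layers.QuantConv2D(
        64, (3, 3),
        input_quantizer="ste_sign",
        kernel_quantizer="ste_sign",
        kernel_constraint="weight_clip",
        use_bias=False),
    keras.layers.BatchNormalization(scale=False),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
])

model = CNN_model(conv_model, keras.metrics.SparseCategoricalAccuracy())
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy())

# Assumed dataset purely for demonstration.
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train[..., None].astype("float32") / 255.0
model.fit(x_train, y_train, batch_size=64, epochs=1)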
Expected behavior
I expect the custom training procedure to work just as well as training with model.fit() directly.
Environment
TensorFlow version: 2.4.0
Larq version: 0.12.0
Top GitHub Comments
As Jelmer mentioned above, it would be great to have a minimal reproducible example to be able to debug this.
Just to verify that the custom training loop is working as expected, could you try changing:

logits = self.conv_model(x)

to

logits = self.conv_model(x, training=True)

?
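For reference, here is a sketch of what that change looks like in context. In a custom train_step the model is not automatically called in training mode, so without the flag, layers such as BatchNormalization run with inference-time moving statistics. Passing training=False in test_step is an additional assumption in the same spirit, not part of the suggestion above.

import tensorflow as tf
from tensorflow import keras

class CNN_model(keras.Model):
    # __init__ and call as in the original report.

    def train_step(self, data):
        x, y = data
        softmax = keras.layers.Softmax()
        with tf.GradientTape() as tape:
            # training=True: run BatchNormalization (and any other
            # mode-dependent layers) in training mode.
            logits = self.conv_model(x, training=True)
            loss_value = self.compiled_loss(y, softmax(logits))
        acc = self.compiled_acc(y, softmax(logits))
        grads = tape.gradient(loss_value, self.conv_model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.conv_model.trainable_variables))
        self.compiled_metrics.update_state(y, softmax(logits))
        return {'train_loss': loss_value, "train_acc": acc}

    def test_step(self, data):
        x, y = data
        softmax = keras.layers.Softmax()
        # training=False: evaluate in inference mode.
        logits = self.conv_model(x, training=False)
        loss_value = self.compiled_loss(y, softmax(logits))
        val_acc = self.compiled_acc(y, softmax(logits))
        return {'loss': loss_value, "val_acc": val_acc}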
Sounds great! Glad this fixes it for you.