Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging 3rd party libraries. It collects links to all the places you might be looking at while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

Questions about customizing training procedures

See original GitHub issue

Describe the bug

I am using a custom training procedure: I subclass keras.Model and define both train_step and test_step. When I use this class to train a quantized network built with larq, I see strange behavior. When only the kernels are quantized, the training loss is very large; when only the activations are quantized, training is normal. When both kernels and activations are quantized, the training loss appears to be in the normal range, but it does not decrease and accuracy does not improve. However, everything works as expected when I train the same network directly with model.fit().

To Reproduce

import tensorflow as tf
from tensorflow import keras


class CNN_model(keras.Model):
    def __init__(self, conv_model, compiled_acc):
        super(CNN_model, self).__init__()
        self.conv_model = conv_model
        self.compiled_acc = compiled_acc

    def train_step(self, data):
        x, y = data
        softmax = keras.layers.Softmax()
        with tf.GradientTape() as tape:
            logits = self.conv_model(x)
            loss_value = self.compiled_loss(y, softmax(logits))
            acc = self.compiled_acc(y, softmax(logits))  # acc
        grads = tape.gradient(loss_value, self.conv_model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.conv_model.trainable_variables))
        self.compiled_metrics.update_state(y, softmax(logits))
        return {"train_loss": loss_value, "train_acc": acc}

    def test_step(self, data):
        x, y = data
        softmax = keras.layers.Softmax()
        logits = self.conv_model(x)
        loss_value = self.compiled_loss(y, softmax(logits))
        val_acc = self.compiled_acc(y, softmax(logits))
        return {"loss": loss_value, "val_acc": val_acc}

    def call(self, inputs):
        return self.conv_model(inputs)

Expected behavior

I expect the custom training loop to behave the same as training directly with model.fit().
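For reference, the built-in training step that model.fit() runs calls the model with training=True, which is one place a hand-written loop can silently diverge from it. A minimal sketch of that pattern (the toy Dense model and explicit loss object are illustrative stand-ins, not the reporter's code; the real default step additionally routes through the compiled loss and metrics):

```python
import tensorflow as tf
from tensorflow import keras

class DefaultLoopModel(keras.Model):
    """Toy model whose train_step mimics what fit() does by default."""

    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(3)
        # Explicit loss object for clarity; fit()'s default step uses the
        # loss passed to compile() instead.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    def call(self, inputs):
        return self.dense(inputs)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            # Key detail: the default loop calls the model with training=True.
            y_pred = self(x, training=True)
            loss = self.loss_fn(y, y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

model = DefaultLoopModel()
model.compile(optimizer="adam")
x = tf.random.normal((8, 4))
y = tf.random.uniform((8,), maxval=3, dtype=tf.int32)
history = model.fit(x, y, epochs=2, verbose=0)
print(history.history["loss"])  # one finite loss value per epoch
```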

Environment

TensorFlow version: 2.4.0
Larq version: 0.12.0

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments:6 (3 by maintainers)

Top GitHub Comments

2 reactions
lgeiger commented, Nov 2, 2021

As Jelmer mentioned above, it would be great to have a minimal reproducible example so we can debug this.

Just to verify that the custom training loop is working as expected, could you try changing logits = self.conv_model(x) to logits = self.conv_model(x, training=True)?
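The training flag matters because many layers behave differently in training and inference mode; larq's quantized layers and BatchNormalization are typical examples. A standalone sketch with BatchNormalization (not from the issue) showing the difference:

```python
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.constant([[1.0], [2.0], [3.0]])

# Default is inference mode: normalizes with the (initial) moving statistics.
y_infer = bn(x)
# training=True: normalizes with the current batch statistics instead,
# and updates the moving averages.
y_train = bn(x, training=True)

# Calling self.conv_model(x) inside train_step without training=True runs
# such layers in inference mode, so the loop trains against the wrong
# forward pass.
print(bool(tf.reduce_all(tf.abs(y_infer - y_train) < 1e-6)))  # False
```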

0 reactions
lgeiger commented, Nov 3, 2021

Sounds great! Glad this fixes it for you.
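For readers landing here later: the resolution is to thread the training flag through the wrapper, i.e. call the backbone with training=True in train_step and training=False in test_step. A condensed, self-contained sketch of the fixed pattern (the toy Dense backbone stands in for the larq network, and the loss is applied to raw logits with from_logits=True rather than through an explicit Softmax layer, which is equivalent but more numerically stable):

```python
import tensorflow as tf
from tensorflow import keras

class CNNWrapper(keras.Model):
    """Condensed version of the reporter's wrapper, with the training flag fixed."""

    def __init__(self, conv_model):
        super().__init__()
        self.conv_model = conv_model
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            logits = self.conv_model(x, training=True)   # the fix
            loss = self.loss_fn(y, logits)
        grads = tape.gradient(loss, self.conv_model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.conv_model.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        x, y = data
        logits = self.conv_model(x, training=False)      # inference mode
        return {"loss": self.loss_fn(y, logits)}

    def call(self, inputs):
        return self.conv_model(inputs)

backbone = keras.Sequential([keras.layers.Flatten(),
                             keras.layers.Dense(10)])
model = CNNWrapper(backbone)
model.compile(optimizer="adam")
x = tf.random.normal((16, 4, 4, 1))
y = tf.random.uniform((16,), maxval=10, dtype=tf.int32)
history = model.fit(x, y, validation_data=(x, y), epochs=1, verbose=0)
```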
