Prevent overfitting in models.


Hi Maciej,

I have been building a fairly simple TFRS model. My query/candidate matrix is small(ish) and quite dense: I have around 2MM query items and only 1K candidates.

The issue I am seeing is overfitting after only 1 or 2 epochs of training. The model is kept as simple as possible: mainly the embedding layer followed by some dense layers, plus one more dense layer in the ranking model.

I am also using BERT as the feature embedding, rather than training my own word vectors. Please see the full model below.

I have tried:

  • using the retrieval task; I do not need it, but it helps slow down overfitting
  • various optimizers and learning rates
  • adding L2 regularization
  • adding dropout

Nothing seems to make any real impact.

Would be great to get some tips on how to prevent TFRS models from overfitting.

Thanks!

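import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_recommenders as tfrs
import tensorflow_text  # noqa: F401 (registers the ops the BERT preprocessing model needs)

# pre_small_bert and small_bert are not defined in the post; they are
# presumably TF Hub handles for a small BERT preprocessor and encoder.

class URLModel(tf.keras.Model):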
    def __init__(self, embed_dims, unique_urls, url_text):
        super().__init__()
                
        self.url_embedding = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=unique_urls, mask_token=None
            ),
            tf.keras.layers.Embedding(len(unique_urls) + 1, embed_dims),
        ])
        
        self.url_text_embedding = self._build_url_text_model()
        
        self.dense_layers = tf.keras.Sequential([
            tf.keras.layers.Dense(32, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.0001)),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(32)
        ])
    
    def _build_url_text_model(self):
        text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
        preprocessing_layer = hub.KerasLayer(pre_small_bert, name='preprocessing')
        encoder_inputs = preprocessing_layer(text_input)
        encoder = hub.KerasLayer(small_bert, trainable=False, name='BERT')
        outputs = encoder(encoder_inputs)
        net = outputs['pooled_output']
        return tf.keras.Model(text_input, net)
    
    def call(self, inputs):
        embedding = tf.concat([
            self.url_embedding(inputs['url']),
            self.url_text_embedding(inputs['text']),
        ], axis=1)
        return self.dense_layers(embedding)

    
class AdvertiserModel(tf.keras.Model):
    def __init__(self, embed_dims, unique_advertisers):
        super().__init__()
        
        self.advertiser_embedding = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=unique_advertisers, mask_token=None
            ),
            tf.keras.layers.Embedding(len(unique_advertisers) + 1, embed_dims)
        ])

    def call(self, inputs):
        embedding = self.advertiser_embedding(inputs)
        return embedding


class RecModel(tfrs.models.Model):
    def __init__(self, embed_dims, unique_urls, url_text, unique_advertisers, rating_weight, retrieval_weight):
        super().__init__()
        
        self.query_model = tf.keras.Sequential([
            URLModel(embed_dims, unique_urls, url_text),
            tf.keras.layers.Dense(embed_dims)
        ])
        
        self.candidate_model = tf.keras.Sequential([
            AdvertiserModel(embed_dims, unique_advertisers),
            tf.keras.layers.Dense(embed_dims)
        ])

        self.rating_model = tf.keras.Sequential([
            tf.keras.layers.Dense(8, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(0.0001)),
            tf.keras.layers.Dense(1),
        ])

        self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
            loss=tf.keras.losses.MeanSquaredError(),
            metrics=[
                tf.keras.metrics.RootMeanSquaredError(),
                tf.keras.metrics.MeanSquaredError()
            ],
        )     

        self.retrieval_task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=tf.data.Dataset.from_tensor_slices(unique_advertisers).batch(128).map(self.candidate_model),
            ),
        )

        self.rating_weight = rating_weight
        self.retrieval_weight = retrieval_weight

    def call(self, features):
        url_embeddings = self.query_model({
            'url': features['url'],
            'text': features['text'],
        })
        advertiser_embeddings = self.candidate_model(features['advertiser_name'])
        return (
            url_embeddings,
            advertiser_embeddings,
            self.rating_model(
                tf.concat([url_embeddings, advertiser_embeddings], axis=1)
            ),
        )   

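    def compute_loss(self, features, training=False):
        # Note: pop() removes the label from the caller's feature dict.
        ratings = features.pop('value')

        # Forward the training flag; without it, Dropout may stay in
        # inference mode (i.e. be skipped) while the model is being fitted.
        url_embeddings, advertiser_embeddings, rating_predictions = self(
            features, training=training
        )

        rating_loss = self.rating_task(
            labels=ratings,
            predictions=rating_predictions,
        )
        retrieval_loss = self.retrieval_task(url_embeddings, advertiser_embeddings)

        # Weighted sum of the ranking and retrieval objectives.
        return self.rating_weight * rating_loss + self.retrieval_weight * retrieval_loss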

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 6

Top GitHub Comments

maciejkula commented, Jan 17, 2021 (2 reactions)

I can’t see anything obviously wrong with your code; I suspect the issue lies with how you define your URL features.
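One way to check this suspicion is to measure how reusable the raw URL strings are. The sketch below is plain Python, not part of TFRS, and the helper name `url_id_coverage` and the toy URLs are hypothetical. It reports the share of distinct training URLs seen only once and the share of validation rows whose URL never appears in training. If both are high, a per-URL embedding has little to learn from beyond memorising individual rows.

```python
from collections import Counter


def url_id_coverage(train_urls, val_urls):
    """Summarise how reusable raw URL strings are as a feature.

    Assumes both lists are non-empty. Returns a pair:
      - share of distinct training URLs that occur exactly once
      - share of validation rows whose URL never occurs in training
    """
    train_counts = Counter(train_urls)
    singletons = sum(1 for count in train_counts.values() if count == 1)
    singleton_share = singletons / len(train_counts)

    unseen = sum(1 for url in val_urls if url not in train_counts)
    unseen_share = unseen / len(val_urls)
    return singleton_share, unseen_share


# Hypothetical toy data, not taken from the issue.
train_urls = [
    "https://example.com/news/a",
    "https://example.com/news/a",
    "https://example.com/news/b",
    "https://example.com/sport/c",
]
val_urls = [
    "https://example.com/news/a",
    "https://example.com/sport/d",
]

singleton_share, unseen_share = url_id_coverage(train_urls, val_urls)
print(f"Distinct training URLs seen once: {singleton_share:.0%}")
print(f"Validation rows with an unseen URL: {unseen_share:.0%}")
```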

Have you tried not using the full unique URL features? It may be responsible for most of your overfitting. Suppose I were trying to build a model that tried to predict the author of a tweet by using the full text (or the unique url) of that tweet as a feature. It would be 100% accurate in training but no better than random in validation. On the other hand, using words of the tweet (or other generalizable text features) would be far more successful.
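The same effect can be reproduced with a toy example, sketched here under stated assumptions: the URLs, advertiser labels, and helper names are all made up, and the two "models" are simple lookups rather than TFRS models. One memorises the label of each exact URL, the other votes using words extracted from the URL.

```python
import re
from collections import Counter, defaultdict
from urllib.parse import urlparse


def url_tokens(url):
    """Split a URL into lowercase host and path words that can recur across URLs."""
    parsed = urlparse(url)
    text = (parsed.netloc + parsed.path).lower()
    return [token for token in re.split(r"[^a-z0-9]+", text) if token]


def fit_id_lookup(rows):
    """Memorise the label of every exact URL; fall back to a majority label."""
    table = dict(rows)
    fallback = Counter(label for _, label in rows).most_common(1)[0][0]
    return lambda url: table.get(url, fallback)


def fit_token_votes(rows):
    """Count labels per URL token and predict by summing the token votes."""
    votes = defaultdict(Counter)
    for url, label in rows:
        for token in url_tokens(url):
            votes[token][label] += 1
    fallback = Counter(label for _, label in rows).most_common(1)[0][0]

    def predict(url):
        total = Counter()
        for token in url_tokens(url):
            total.update(votes.get(token, {}))
        return total.most_common(1)[0][0] if total else fallback

    return predict


def accuracy(predict, rows):
    return sum(predict(url) == label for url, label in rows) / len(rows)


# Hypothetical data: every validation URL is new, but its words are not.
train = [
    ("https://example.com/sport/football-final", "sports_brand"),
    ("https://example.com/sport/tennis-open", "sports_brand"),
    ("https://example.com/food/pasta-recipe", "grocer"),
    ("https://example.com/food/soup-recipe", "grocer"),
]
val = [
    ("https://example.com/sport/rugby-cup", "sports_brand"),
    ("https://example.com/food/salad-recipe", "grocer"),
]

id_model = fit_id_lookup(train)
token_model = fit_token_votes(train)

id_train_acc, id_val_acc = accuracy(id_model, train), accuracy(id_model, val)
token_train_acc, token_val_acc = accuracy(token_model, train), accuracy(token_model, val)
print("exact URL :", id_train_acc, id_val_acc)
print("URL tokens:", token_train_acc, token_val_acc)
```

On this toy data the exact-URL lookup is perfect on the training rows and at chance level on validation, while the token-based version generalises to the unseen URLs. In the model from the question, the analogous change would be to drop or shrink the per-URL embedding and rely on features shared across URLs, such as the BERT text embedding.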

ydennisy commented, Jan 29, 2021 (0 reactions)

Ok thanks @maciejkula will close this - appreciate your help!
