Prevent overfitting in models.
Hi Maciej,
I have been building a TFRS model. It is quite simple, and my query/candidate matrix is small(ish) and quite dense: I have around 2MM query items and only 1K candidates.
The issue I am seeing is overfitting after only 1-2 epochs of training, even though the model is kept as simple as possible: mainly just the embedding layers, some dense layers following the embeddings, and one more dense layer in the ranking model.
I am also using BERT as the feature embedding, rather than training my own word vectors. Please see the full model below.
I have tried:
- using the retrieval task, even though I do not need it; it does help slow down the overfitting
- various optimizers and learning rates
- adding L2
- adding dropout
None of these seems to make any real impact.
Would be great to get some tips on how to prevent TFRS models from overfitting.
Thanks!
```python
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_recommenders as tfrs


class URLModel(tf.keras.Model):
    """Query tower: combines a per-URL embedding with a frozen BERT text embedding."""

    def __init__(self, embed_dims, unique_urls, url_text):
        super().__init__()
        self.url_embedding = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=unique_urls, mask_token=None
            ),
            tf.keras.layers.Embedding(len(unique_urls) + 1, embed_dims),
        ])
        self.url_text_embedding = self._build_url_text_model()
        self.dense_layers = tf.keras.Sequential([
            tf.keras.layers.Dense(
                32, activation='relu',
                kernel_regularizer=tf.keras.regularizers.L2(0.0001)),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(32),
        ])

    def _build_url_text_model(self):
        # `pre_small_bert` and `small_bert` are the TF Hub handles for the
        # BERT preprocessing model and encoder (defined elsewhere).
        text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')
        preprocessing_layer = hub.KerasLayer(pre_small_bert, name='preprocessing')
        encoder_inputs = preprocessing_layer(text_input)
        encoder = hub.KerasLayer(small_bert, trainable=False, name='BERT')
        outputs = encoder(encoder_inputs)
        net = outputs['pooled_output']
        return tf.keras.Model(text_input, net)

    def call(self, inputs):
        embedding = tf.concat([
            self.url_embedding(inputs['url']),
            self.url_text_embedding(inputs['text']),
        ], axis=1)
        return self.dense_layers(embedding)


class AdvertiserModel(tf.keras.Model):
    """Candidate tower: a simple advertiser embedding."""

    def __init__(self, embed_dims, unique_advertisers):
        super().__init__()
        self.advertiser_embedding = tf.keras.Sequential([
            tf.keras.layers.experimental.preprocessing.StringLookup(
                vocabulary=unique_advertisers, mask_token=None
            ),
            tf.keras.layers.Embedding(len(unique_advertisers) + 1, embed_dims),
        ])

    def call(self, inputs):
        return self.advertiser_embedding(inputs)


class RecModel(tfrs.models.Model):
    """Joint model: a weighted combination of a ranking and a retrieval task."""

    def __init__(self, embed_dims, unique_urls, url_text, unique_advertisers,
                 rating_weight, retrieval_weight):
        super().__init__()
        self.query_model = tf.keras.Sequential([
            URLModel(embed_dims, unique_urls, url_text),
            tf.keras.layers.Dense(embed_dims),
        ])
        self.candidate_model = tf.keras.Sequential([
            AdvertiserModel(embed_dims, unique_advertisers),
            tf.keras.layers.Dense(embed_dims),
        ])
        self.rating_model = tf.keras.Sequential([
            tf.keras.layers.Dense(
                8, activation='relu',
                kernel_regularizer=tf.keras.regularizers.L2(0.0001)),
            tf.keras.layers.Dense(1),
        ])
        self.rating_task: tf.keras.layers.Layer = tfrs.tasks.Ranking(
            loss=tf.keras.losses.MeanSquaredError(),
            metrics=[
                tf.keras.metrics.RootMeanSquaredError(),
                tf.keras.metrics.MeanSquaredError(),
            ],
        )
        self.retrieval_task = tfrs.tasks.Retrieval(
            metrics=tfrs.metrics.FactorizedTopK(
                candidates=tf.data.Dataset.from_tensor_slices(unique_advertisers)
                    .batch(128)
                    .map(self.candidate_model),
            ),
        )
        self.rating_weight = rating_weight
        self.retrieval_weight = retrieval_weight

    def call(self, features):
        url_embeddings = self.query_model({
            'url': features['url'],
            'text': features['text'],
        })
        advertiser_embeddings = self.candidate_model(features['advertiser_name'])
        return (
            url_embeddings,
            advertiser_embeddings,
            self.rating_model(
                tf.concat([url_embeddings, advertiser_embeddings], axis=1)
            ),
        )

    def compute_loss(self, features, training=False):
        ratings = features.pop('value')
        url_embeddings, advertiser_embeddings, rating_predictions = self(features)
        rating_loss = self.rating_task(
            labels=ratings,
            predictions=rating_predictions,
        )
        retrieval_loss = self.retrieval_task(url_embeddings, advertiser_embeddings)
        return (self.rating_weight * rating_loss
                + self.retrieval_weight * retrieval_loss)
```
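For reference, a minimal training setup that makes the overfitting visible and bounds it with early stopping might look like the sketch below. The dataset names (`train`, `validation`), the optimizer, the batch size, and the task weights are assumptions for illustration, not taken from the issue.

```python
import tensorflow as tf

# Hypothetical instantiation; all hyperparameters below are assumed.
model = RecModel(
    embed_dims=32,
    unique_urls=unique_urls,
    url_text=url_text,
    unique_advertisers=unique_advertisers,
    rating_weight=1.0,
    retrieval_weight=1.0,
)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))

# Stop as soon as validation loss turns up, and roll back to the best weights.
# TFRS models report a 'total_loss' metric, which Keras prefixes with 'val_'
# on the validation set.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_total_loss',
    patience=1,
    restore_best_weights=True,
)

model.fit(
    train.batch(4096),            # `train` and `validation` are assumed to be
    validation_data=validation.batch(4096),  # tf.data Datasets of feature dicts
    epochs=10,
    callbacks=[early_stopping],
)
```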

I can’t see anything obviously wrong with your code; I suspect the issue lies with how you define your URL features.
Have you tried not using the full unique-URL feature? It may be responsible for most of your overfitting. Suppose I were trying to build a model that predicts the author of a tweet using the full text (or the unique URL) of that tweet as a feature. It would be 100% accurate in training but no better than random in validation. Using the words of the tweet (or other generalizable text features) instead would be far more successful.
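One way to act on this suggestion (a rough sketch, not from the thread) is to tokenize each URL and average per-token embeddings, in the style of the TFRS feature-preprocessing tutorial, instead of looking up one embedding per unique URL. Here `unique_urls` comes from the question; `max_tokens` and the separator regex are assumptions.

```python
import tensorflow as tf

max_tokens = 20_000  # assumed vocabulary budget; tune to your data
embed_dims = 32

def split_url(url):
    # Replace common URL separators with spaces so the default whitespace
    # split yields path/word tokens rather than one opaque string per page.
    return tf.strings.regex_replace(url, r'[/\-_.?=&:]', ' ')

url_tokenizer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=max_tokens,
    standardize=split_url,
)
url_tokenizer.adapt(unique_urls)  # learn the token vocabulary from the URLs

url_token_embedding = tf.keras.Sequential([
    url_tokenizer,
    tf.keras.layers.Embedding(max_tokens, embed_dims, mask_zero=True),
    tf.keras.layers.GlobalAveragePooling1D(),  # average tokens into one vector
])
```

This `url_token_embedding` could replace (or shrink) the per-URL `StringLookup` embedding in `URLModel`, so the model can only learn from features that generalize across URLs.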
Ok, thanks @maciejkula, will close this. Appreciate your help!