
Computing Top-K accuracy on validation data is disproportionately slow

See original GitHub issue

I use TFRS in an e-commerce retail setting with a lot of purchase-history data as well as click-stream data, using a multitask recommender model. I have created a model that trains, evaluates and serves fine, with one unfortunate issue: validation during training takes 5x the time of training, on a fraction of the data. To exemplify, I have 10M rows of interactions, which I split 70/20/10 into training/validation/test. There are 1M unique users and 100k unique items.

I train using 4 GPUs, and one epoch takes ~4 minutes to run through the 7M training rows, but another ~20 minutes to compute Top-K accuracy on the validation data (2M rows). During training, I make sure to specify compute_metrics=not training.

My retrieval task is set up as follows, where I use as large a batch size as my hardware allows to speed it up as much as possible.

self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
    loss=...,
    metrics=tfrs.metrics.FactorizedTopK(candidates=items_ds.batch(8192).map(self.item_model))
)

And I train as follows:

model.fit(train.batch(BATCH_SIZE),
          epochs=15,
          validation_data=val.batch(BATCH_SIZE)
)

If I run 15 epochs I spend 1 hour training and 5 hours validating, on one seventh of the data. Is this immense time difference expected? Anything I can do to improve the performance during validation?

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 7

Top GitHub Comments

2 reactions
maciejkula commented, Oct 18, 2021

I’d always use a real metric that I care about - the top-K accuracy in this case.

0 reactions
patrickorlando commented, Sep 17, 2022

Hi @hkristof03,

I’m certainly not an expert in this area and you may have already implemented it in this way, but I’ll share my thoughts. It might help if you could post a code sample of the val_step and where you are computing the new candidates before each validation epoch.

Since you only have 50K items, you can probably get away with the BruteForce Index.

My approach would be to:

  1. Create or load your candidates in a tf.data.Dataset.
  2. Create a BruteForce layer and index it with brute_force.index_from_dataset(candidate_ds.batch(512).map(candidate_model)). These vectors will initially be random.
  3. Define the train_step and val_step outside of your training loop. The val_step will use the index created above. Decorate these with tf.function.
  4. Before each validation epoch, re-index the retrieval layer (don’t create a new layer each time), use the same candidate_ds. Ensure this function is not wrapped in a tf.function decorator.

I think this should avoid a memory explosion because the candidates are stored as tf.Variable and re-indexing will update the state of those variables, rather than creating new nodes in the graph. I’m not exactly sure how the ScaNN layer will behave, so I’d try the BruteForce first.

Let me know if this works 🤞


