Computing Top-K accuracy on validation data is disproportionately slow
I use TFRS in an e-commerce retail setting with a lot of purchase-history data as well as click-stream data, using a multitask recommender model. I have created a model that trains, evaluates and serves fine, with one unfortunate issue: validation during training takes 5x as long as training itself, on a fraction of the data. To give concrete numbers, I have 10M rows of interactions, which I split 70/20/10 into training/validation/test sets. There are 1M unique users and 100k unique items.
I train using 4 GPUs, and one epoch takes ~4 minutes to run through the 7M training rows, but another ~20 minutes to compute Top-K accuracy on the 2M validation rows. During training, I make sure to pass compute_metrics=not training.
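For reference, that flag is threaded through compute_loss roughly like this (a trimmed sketch; user_model, item_model and the feature names stand in for my actual multitask model):

import tensorflow_recommenders as tfrs

class MultitaskModel(tfrs.models.Model):
    def compute_loss(self, features, training=False):
        user_embeddings = self.user_model(features["user_id"])
        item_embeddings = self.item_model(features["item_id"])
        # Skip FactorizedTopK on training batches; it still runs on
        # every validation batch, which is where the time goes.
        return self.retrieval_task(
            user_embeddings,
            item_embeddings,
            compute_metrics=not training,
        )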
My retrieval task is set up as follows, using as large a batch size as my hardware allows to speed it up as much as possible:
self.retrieval_task: tf.keras.layers.Layer = tfrs.tasks.Retrieval(
    loss=...,
    metrics=tfrs.metrics.FactorizedTopK(
        candidates=items_ds.batch(8192).map(self.item_model)
    ),
)
And I train as follows:
model.fit(
    train.batch(BATCH_SIZE),
    epochs=15,
    validation_data=val.batch(BATCH_SIZE),
)
If I run 15 epochs I spend 1 hour training and 5 hours validating, even though the validation set is less than a third the size of the training set. Is this immense time difference expected? Is there anything I can do to improve performance during validation?

I’d always use a real metric that I care about - the top-K accuracy in this case.
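If running the full FactorizedTopK evaluation every epoch is too slow, one workaround (a sketch using only standard Keras fit arguments, nothing TFRS-specific; the subsample size is illustrative) is to compute it less often, or on a subsample of the validation set:

model.fit(
    train.batch(BATCH_SIZE),
    epochs=15,
    # Evaluate on a subsample of the validation set (ideally shuffled
    # once beforehand, since take() returns the first N examples)...
    validation_data=val.take(200_000).batch(BATCH_SIZE),
    # ...and only every 5th epoch.
    validation_freq=5,
)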
Hi @hkristof03,
I'm certainly not an expert in this area, and you may have already implemented it this way, but I'll share my thoughts. It might help if you could post a code sample of the val_step and of where you are computing the new candidates before each validation epoch. Since you only have 50K items, you can probably get away with the BruteForce index.
My approach would be to:
1. Build a BruteForce index from your candidate tf.data.Dataset: brute_force.index_from_dataset(candidate_ds.batch(512).map(candidate_model)). These vectors will initially be random.
2. Define the train_step and test_step outside of your training loop. The val_step will use the index created above. Decorate these with tf.function.
3. Before each validation epoch, re-index the candidates inside a function that also carries the tf.function decorator.

I think this should avoid a memory explosion, because the candidates are stored as tf.Variables and re-indexing will update the state of those variables rather than creating new nodes in the graph. I'm not exactly sure how the ScaNN layer will behave, so I'd try BruteForce first. Let me know if this works 🤞
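For concreteness, here is a minimal sketch of that loop. Instead of the BruteForce layer it stores the candidates in an explicit tf.Variable and scores with a plain matmul (the same computation BruteForce performs internally), which guarantees the in-place-update property described above. Names are illustrative: query_model, candidate_model, candidate_ds, val_ds, the "user_id"/"item_index" features, and EMBED_DIM are assumptions, not your actual code.

import tensorflow as tf

NUM_ITEMS = 50_000  # catalog size from the thread
EMBED_DIM = 64      # illustrative embedding width

# Candidate embeddings live in one fixed tf.Variable, so re-indexing
# assigns new values in place and the traced val_step keeps reading
# the same variable instead of adding nodes to the graph.
candidate_embeddings = tf.Variable(
    tf.zeros((NUM_ITEMS, EMBED_DIM)), trainable=False
)

def reindex():
    # Recompute all candidate vectors with the current item-model weights.
    batches = [candidate_model(b) for b in candidate_ds.batch(512)]
    candidate_embeddings.assign(tf.concat(batches, axis=0))

top_k_accuracy = tf.keras.metrics.Mean(name="val_top_100_accuracy")

@tf.function
def val_step(features):
    query_embeddings = query_model(features["user_id"])
    # Brute-force scoring: one matmul against the full catalog.
    scores = tf.matmul(query_embeddings, candidate_embeddings, transpose_b=True)
    _, retrieved = tf.math.top_k(scores, k=100)
    # Hit if the true item's catalog row appears among the top 100.
    hits = tf.reduce_any(
        tf.equal(retrieved, tf.cast(features["item_index"][:, None], tf.int32)),
        axis=1,
    )
    top_k_accuracy.update_state(tf.cast(hits, tf.float32))

val_batches = val_ds.batch(8192)
for epoch in range(15):
    # ... run train_step over the training batches ...
    reindex()  # refresh candidate embeddings before validating
    for batch in val_batches:
        val_step(batch)
    print(f"epoch {epoch}: top-100 accuracy {top_k_accuracy.result().numpy():.4f}")
    top_k_accuracy.reset_state()

Here "item_index" is assumed to be the integer row of the true item in the candidate catalog, so the hit test can compare retrieved indices against labels directly; if your items are keyed by string ids, index with (id, embedding) pairs and compare ids instead.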