Why does the Retrieval class use an identity matrix for labels?
Hi there,
Thanks so much for releasing and maintaining this code base. It’s really fantastic.
I don’t really have a bug to report, but I do have a question about the Retrieval class and how its loss is calculated. I’ve been working through the tutorials, focusing on the basic retrieval example, and I understand that by default the task uses a categorical cross-entropy loss.
Obviously, this implies having a label and predicted probabilities which I can see in the Retrieval class.
From that class, the scores are the matrix multiplication of the query and candidate embeddings:
```python
scores = tf.linalg.matmul(
    query_embeddings, candidate_embeddings, transpose_b=True)
```
Then, the labels are derived as:
```python
labels = tf.eye(num_queries, num_candidates)
```
Which is then passed to tf.keras.losses.CategoricalCrossentropy to calculate the loss:
```python
loss = self._loss(y_true=labels, y_pred=scores, sample_weight=sample_weight)
```
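To make sure I’m reading it right, here is the computation reduced to a runnable toy (a batch of 2 with random embeddings standing in for the tower outputs; I believe the default loss in tfrs is built with `from_logits=True` since the scores are raw dot products):

```python
import tensorflow as tf

# Toy batch: 2 queries and their 2 positive candidates, embedding dim 4.
# Random values stand in for what the query/candidate towers would produce.
query_embeddings = tf.random.normal((2, 4))
candidate_embeddings = tf.random.normal((2, 4))

# Score every query against every candidate in the batch: shape (2, 2).
scores = tf.linalg.matmul(
    query_embeddings, candidate_embeddings, transpose_b=True)

# Identity labels: the positive for query i is candidate i of the same batch row.
labels = tf.eye(2, 2)

loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
print(loss_fn(y_true=labels, y_pred=scores))
```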
What I don’t quite understand is why is the identity matrix used as the labels? Doesn’t this imply that user_i has selected candidate_i? Or am I missing something?
If I relate this back to the basic retrieval example, then would the scores matrix be of size number_unique_user_ids x number_unique_movie_ids? Likewise, the rows of the loss matrix would relate to a user_id and the columns to a candidate. Wouldn’t this imply that user_1 reviewed candidate_1, etc…?
Apologies if this is a fairly basic question, but I’m quite new to Tensorflow. Would appreciate any feedback or references. I’ve tried looking at the issues here and also on stackoverflow, but haven’t really been able to find anything. Thanks.
Not quite.
Let’s imagine you have 3 users and 4 items, and that the positive interactions are stored as a table of (user, item) pairs, one row per interaction. The rows referenced below are: row 1 = (user 1, item 3), row 4 = (user 2, item 3), and row 5 = (user 3, item 2).
Assume a batch size of 2, and a query/candidate embedding of size 12. We randomly sample 2 rows from the table above, rows 1 and 5.
Your queries will be a matrix of shape (2, 12). The first row will be the query for user 1, and the second row for user 3. Your candidates will be a matrix of shape (2, 12). The first row will be for item 3, and the second for item 2.
When you perform the dot product between queries and candidates you get a score matrix of shape (2, 2). The rows are the users, and the columns are the items.
The diagonal of this matrix contains the scores for the positive interactions we sampled. All the other elements are the scores between that query and the positive items of the other examples in the batch, which we then use as negatives for that example.
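To make that concrete, here is a toy version of the batch above (the embedding values are random stand-ins, so only the shapes and the positions of the positives matter):

```python
import tensorflow as tf

# Batch sampled from rows 1 and 5: (user 1, item 3) and (user 3, item 2).
queries = tf.random.normal((2, 12))     # row 0 -> user 1, row 1 -> user 3
candidates = tf.random.normal((2, 12))  # row 0 -> item 3, row 1 -> item 2

scores = tf.linalg.matmul(queries, candidates, transpose_b=True)  # shape (2, 2)

# scores[0, 0]: user 1 vs item 3 -> sampled positive
# scores[1, 1]: user 3 vs item 2 -> sampled positive
# scores[0, 1]: user 1 vs item 2 -> in-batch negative
# scores[1, 0]: user 3 vs item 3 -> in-batch negative
labels = tf.eye(2)  # the diagonal marks the positives
```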
This matrix is not representative of the global interaction matrix. Suppose we instead sample a batch of rows 1 and 4. In this case both users interacted with item 3, so both columns of the resulting 2 x 2 score matrix are scores against the same item.
Here, the in-batch negative for each positive is the very same item as that positive. This is an accidental hit, and the tfrs library has the ability to remove these hits if you pass the candidate ids in (via the `remove_accidental_hits` option).
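Roughly like this, assuming a recent tfrs version where `tfrs.tasks.Retrieval` accepts `remove_accidental_hits` (the tensors are again toy stand-ins):

```python
import tensorflow as tf
import tensorflow_recommenders as tfrs

# Toy batch of 2 where both sampled positives are the same item (id 3),
# reproducing the accidental-hit situation above.
query_embeddings = tf.random.normal((2, 12))
candidate_embeddings = tf.random.normal((2, 12))
candidate_ids = tf.constant([3, 3])

task = tfrs.tasks.Retrieval(remove_accidental_hits=True)

# Passing candidate_ids lets the task mask out any in-batch negative that
# shares an id with the row's positive, so it isn't treated as a false negative.
loss = task(query_embeddings, candidate_embeddings, candidate_ids=candidate_ids)
```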
So in summary, each row of the scores matrix corresponds to a single (user, item) pair. The diagonal is the score for that pair, and all other columns are negatives sampled from the other pairs in that same mini-batch.
Thank you for the wonderful explanations, @patrickorlando!