diff --git a/tensorflow_recommenders/tasks/retrieval.py b/tensorflow_recommenders/tasks/retrieval.py index bd44e05..bc219c5 100644 --- a/tensorflow_recommenders/tasks/retrieval.py +++ b/tensorflow_recommenders/tasks/retrieval.py @@ -199,8 +199,8 @@ def call(self, query_embeddings, # Slice to the size of query embeddings # if `candidate_embeddings` contains extra negatives. - candidate_embeddings[:tf.shape(query_embeddings)[0]], - true_candidate_ids=candidate_ids) + candidate_embeddings[:num_queries], + true_candidate_ids=candidate_ids[:num_queries]) ) if compute_batch_metrics: