[BUG] Error when applying the two-tower model to a customized dataset

Bug description

When applying a customized dataset to the two-tower model, I ran into issues getting the ragged shapes to line up with the built-in models.

Steps/Code to reproduce bug

Code:
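import merlin.models.tf as mm

# `schema`, `train` and `valid` come from earlier notebook cells (not shown in the
# issue): the dataset schema and the training/validation datasets.
model = mm.TwoTowerModel(
    schema,
    # Query and item towers: two-layer MLPs whose last layer has no activation.
    query_tower=mm.MLPBlock([128, 64], no_activation_last_layer=True),
    item_tower=mm.MLPBlock([128, 64], no_activation_last_layer=True),
    # Use the other items in the same batch as negatives.
    samplers=[mm.InBatchSampler()],
    embedding_options=mm.EmbeddingOptions(infer_embedding_sizes=True),
)
model.compile(optimizer="adam", run_eagerly=False, metrics=[mm.RecallAt(10), mm.NDCGAt(10)])
model.fit(train, validation_data=valid, batch_size=128, epochs=3)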
Error:
2022-08-10 17:59:42.046220: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
 
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [11], in <cell line: 1>()
----> 1 model.fit(train, validation_data=valid, batch_size=128, epochs=3)

File /models/merlin/models/tf/models/base.py:351, in Model.fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing, train_metrics_steps, **kwargs)
    345     validation_data = BatchedDataset(
    346         validation_data, batch_size=batch_size, shuffle=False, **kwargs
    347     )
    349 callbacks = self._add_metrics_callback(callbacks, train_metrics_steps)
--> 351 return super().fit(
    352    x,
    353    y,
    354    batch_size,
    355    epochs,
    356    verbose,
    357    callbacks,
    358    validation_split,
    359    validation_data,
    360    shuffle,
    361    class_weight,
    362    sample_weight,
    363    initial_epoch,
    364    steps_per_epoch,
    365    validation_steps,
    366    validation_batch_size,
    367    validation_freq,
    368    max_queue_size,
    369    workers,
    370    use_multiprocessing,
    371 )

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File /models/merlin/models/tf/models/base.py:125, in Model.call(self, inputs, **kwargs)
    124 def call(self, inputs, **kwargs):
--> 125     outputs = self.block(inputs, **kwargs)
    126     return outputs

File /models/merlin/models/config/schema.py:55, in SchemaMixin.__call__(self, *args, **kwargs)
     52 def __call__(self, *args, **kwargs):
     53     self.check_schema()
---> 55     return super().__call__(*args, **kwargs)

File /models/merlin/models/tf/blocks/core/base.py:275, in Block._maybe_build(self, inputs)
    272 if getattr(self, "_context", None) and not self.context.built:
    273     self.context.set_dtypes(inputs)
--> 275 super()._maybe_build(inputs)

File /models/merlin/models/tf/blocks/core/combinators.py:109, in SequentialBlock.build(self, input_shape)
    107 for layer in self.layers:
    108     try:
--> 109         layer.build(input_shape)
    110     except TypeError:
    111         t, v, tb = sys.exc_info()

File /models/merlin/models/tf/blocks/core/combinators.py:414, in ParallelBlock.build(self, input_shape)
    412 else:
    413     for layer in self.parallel_values:
--> 414         layer.build(input_shape)
    416 return super().build(input_shape)

File /models/merlin/models/tf/blocks/core/combinators.py:109, in SequentialBlock.build(self, input_shape)
    107 for layer in self.layers:
    108     try:
--> 109         layer.build(input_shape)
    110     except TypeError:
    111         t, v, tb = sys.exc_info()

File /models/merlin/models/tf/models/base.py:60, in ModelBlock.build(self, input_shapes)
     59 def build(self, input_shapes):
---> 60     self.block.build(input_shapes)
     62     if not hasattr(self.build, "_is_default"):
     63         self._build_input_shape = input_shapes

File /models/merlin/models/tf/blocks/core/combinators.py:109, in SequentialBlock.build(self, input_shape)
    107 for layer in self.layers:
    108     try:
--> 109         layer.build(input_shape)
    110     except TypeError:
    111         t, v, tb = sys.exc_info()

File /models/merlin/models/tf/blocks/core/combinators.py:109, in SequentialBlock.build(self, input_shape)
    107 for layer in self.layers:
    108     try:
--> 109         layer.build(input_shape)
    110     except TypeError:
    111         t, v, tb = sys.exc_info()

File /models/merlin/models/tf/blocks/core/combinators.py:118, in SequentialBlock.build(self, input_shape)
    113             v = TypeError(
    114                 f"Couldn't build {layer}, "
    115                 f"did you forget to add aggregation to {last_layer}?"
    116             )
    117         six.reraise(t, v, tb)
--> 118     input_shape = layer.compute_output_shape(input_shape)
    119     last_layer = layer
    120 self.built = True

File /models/merlin/models/tf/blocks/mlp.py:211, in _Dense.compute_output_shape(self, input_shape)
    209 if isinstance(input_shape, dict):
    210     agg = tabular_aggregation_registry.parse(self.pre_aggregation)
--> 211     input_shape = agg.compute_output_shape(input_shape)
    213 return super(_Dense, self).compute_output_shape(input_shape)

File /models/merlin/models/tf/blocks/core/aggregation.py:69, in ConcatFeatures.compute_output_shape(self, input_shapes)
     67 agg_dim = sum([i[-1] for i in input_shapes.values()])
     68 if isinstance(agg_dim, tf.TensorShape):
---> 69     raise ValueError(f"Not possible to aggregate, received: {input_shapes}.")
     70 output_size = self._get_agg_output_size(input_shapes, agg_dim)
     71 return output_size

ValueError: Exception encountered when calling layer "retrieval_model" (type RetrievalModel).

Not possible to aggregate, received: {'duration_ms_songs_pl': (TensorShape([6677, 1]), TensorShape([128, 1])), 'artist_pop_pl': (TensorShape([6677, 1]), TensorShape([128, 1])), 'artists_followers_pl': (TensorShape([6677, 1]), TensorShape([128, 1])), 'track_pop_pl': (TensorShape([6677, 1]), TensorShape([128, 1])), 'duration_seed_track': TensorShape([128, 1]), 'track_pop_seed_track': TensorShape([128, 1]), 'artist_pop_seed_track': TensorShape([128, 1]), 'artist_followers_seed_track': TensorShape([128, 1]), 'duration_ms_seed_pl': TensorShape([128, 1]), 'n_songs_pl': TensorShape([128, 1]), 'num_artists_pl': TensorShape([128, 1]), 'num_albums_pl': TensorShape([128, 1]), 'artist_name_pl': TensorShape([128, 43]), 'track_name_pl': TensorShape([128, 65]), 'album_name_pl': TensorShape([128, 52]), 'artist_genres_pl': TensorShape([128, 28]), 'artist_name_seed_track': TensorShape([128, 31]), 'artist_uri_seed_track': TensorShape([128, 32]), 'track_name_seed_track': TensorShape([128, 43]), 'track_uri_seed_track': TensorShape([128, 46]), 'album_name_seed_track': TensorShape([128, 37]), 'album_uri_seed_track': TensorShape([128, 39]), 'artist_genres_seed_track': TensorShape([128, 24]), 'description_pl': TensorShape([128, 21]), 'name': TensorShape([128, 33]), 'collaborative': TensorShape([128, 3])}.

Call arguments received:
  • inputs={'artist_name_pl': ('tf.Tensor(shape=(6677, 1), dtype=int32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'track_name_pl': ('tf.Tensor(shape=(6677, 1), dtype=int32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'album_name_pl': ('tf.Tensor(shape=(6677, 1), dtype=int32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'artist_genres_pl': ('tf.Tensor(shape=(6677, 1), dtype=int32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'track_uri_can': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'artist_name_can': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'track_name_can': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'artist_genres_can': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'artist_name_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'artist_uri_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'track_name_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'track_uri_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'album_name_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'album_uri_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'artist_genres_seed_track': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'description_pl': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'name': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'collaborative': 'tf.Tensor(shape=(128, 1), dtype=int32)', 'duration_ms_songs_pl': ('tf.Tensor(shape=(6677, 1), dtype=float32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'artist_pop_pl': ('tf.Tensor(shape=(6677, 1), dtype=float32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'artists_followers_pl': ('tf.Tensor(shape=(6677, 1), dtype=float32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'track_pop_pl': ('tf.Tensor(shape=(6677, 1), dtype=float32)', 'tf.Tensor(shape=(128, 1), dtype=int32)'), 'duration_ms_can': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'track_pop_can': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'artist_pop_can': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'artist_followers_can': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'duration_seed_track': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'track_pop_seed_track': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'artist_pop_seed_track': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'artist_followers_seed_track': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'duration_ms_seed_pl': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'n_songs_pl': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'num_artists_pl': 'tf.Tensor(shape=(128, 1), dtype=float32)', 'num_albums_pl': 'tf.Tensor(shape=(128, 1), dtype=float32)'}
  • kwargs={'training': 'False'}
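
The last frames of the traceback show where the build fails: ConcatFeatures.compute_output_shape sums the last dimension of every input shape. Most features already arrive as a single shape such as (128, 43), but the four continuous list columns (duration_ms_songs_pl, artist_pop_pl, artists_followers_pl, track_pop_pl) still arrive as a pair of shapes, (6677, 1) and (128, 1), presumably the flattened values and the per-row lengths for a batch of 128. A pair has no single width to add up. The stdlib-only sketch below illustrates that kind of check with shapes as plain tuples; it is a simplified model, not Merlin's actual code, and the helper names are hypothetical.

```python
from typing import Dict, Tuple, Union

Shape = Tuple[int, ...]
# A feature is either one dense shape or, for a ragged list column, a pair of shapes.
FeatureShape = Union[Shape, Tuple[Shape, Shape]]


def is_dense_2d(shape: FeatureShape) -> bool:
    """True if `shape` is a single (batch, width) shape made of ints."""
    return (
        isinstance(shape, tuple)
        and len(shape) == 2
        and all(isinstance(dim, int) for dim in shape)
    )


def concat_output_width(input_shapes: Dict[str, FeatureShape]) -> Tuple[int, int]:
    """Simplified concat check: return (batch, total_width) or raise ValueError."""
    batch_sizes = set()
    total_width = 0
    for name, shape in input_shapes.items():
        if not is_dense_2d(shape):
            # A (values, row_lengths) pair has no single last dimension to add up.
            raise ValueError(f"Not possible to aggregate {name!r}, received: {shape}")
        batch_sizes.add(shape[0])
        total_width += shape[1]
    if len(batch_sizes) != 1:
        raise ValueError(f"Expected exactly one batch size, got {sorted(batch_sizes)}")
    return batch_sizes.pop(), total_width


# Shapes modeled on the error message (hypothetical subset of the features).
dense = {"n_songs_pl": (128, 1), "artist_name_pl": (128, 43)}
print(concat_output_width(dense))  # (128, 44)

ragged = dict(dense, duration_ms_songs_pl=((6677, 1), (128, 1)))
try:
    concat_output_width(ragged)
except ValueError as err:
    print(err)
```

The dense features concatenate to a width of 44, while adding a single ragged pair makes the check fail, which mirrors the "Not possible to aggregate" error in the report.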

Expected behavior

Merlin Models should accept the customized input shapes.

Environment details

NGC merlin TF training container 22.04

Additional context

Repository: https://github.com/jswortz/spotify_mpd_two_tower

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
rnyak commented, Aug 17, 2022

@mengdong Please see my input above. The error is caused by feeding variable-length continuous list columns to the Two-Tower model. This is problematic for now because of the tuple representation of list columns. Currently we don't support variable-length multi-hot continuous columns out of the box. This will be worked on in the near future.
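
Until this is supported, one possible workaround (not proposed in the thread, and only an assumption about what suits this dataset) is to turn each variable-length continuous list into fixed-size features before the data reaches the model: either pad/truncate every row to a fixed length (stored, for example, as that many separate scalar columns) or pool each row to a single value. The stdlib-only sketch below assumes the tuple holds the flattened values plus per-row lengths, as the (6677, 1) / (128, 1) shapes suggest; split_rows, pad_or_truncate and mean_pool are hypothetical helpers, not a Merlin API.

```python
from typing import List, Sequence


def split_rows(values: Sequence[float], row_lengths: Sequence[int]) -> List[List[float]]:
    """Rebuild per-example lists from flattened values and per-row lengths."""
    if sum(row_lengths) != len(values):
        raise ValueError("row lengths do not add up to the number of values")
    rows, start = [], 0
    for n in row_lengths:
        rows.append(list(values[start:start + n]))
        start += n
    return rows


def pad_or_truncate(rows: List[List[float]], length: int, pad: float = 0.0) -> List[List[float]]:
    """Give every row exactly `length` entries: cut long rows, pad short ones."""
    return [row[:length] + [pad] * max(0, length - len(row)) for row in rows]


def mean_pool(rows: List[List[float]], empty: float = 0.0) -> List[float]:
    """Collapse each row to one value (its mean); empty rows get `empty`."""
    return [sum(row) / len(row) if row else empty for row in rows]


rows = split_rows([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], row_lengths=[3, 1, 2])
print(rows)                      # [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
print(pad_or_truncate(rows, 2))  # [[1.0, 2.0], [4.0, 0.0], [5.0, 6.0]]
print(mean_pool(rows))           # [2.0, 4.0, 5.5]
```

Pooling loses the order and count of the list entries, while padding keeps the first few entries at the cost of a fixed cutoff; which trade-off is acceptable depends on the feature (for example, playlist track durations versus popularity scores).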

0 reactions
rnyak commented, Sep 12, 2022

Top Results From Across the Web

A Dual Augmented Two-tower Model for Online Large-scale ...
However, the model suffers from lack of information interaction between the two towers. Besides, imbalanced category data also hinders the model performance.

The 5-Step Recipe To Make Your Deep Learning Models Bug ...
Finally, poor model performance could be caused not by your model but your dataset construction. Common issues here include not having enough ...

Personalized recommendations - IV (two tower models for ...
Two tower models to compute user and item embeddings. The aim here is to train two neural networks, a user encoder and an...

Cross-Batch Negative Sampling for Training Two-Tower ...
In the batch training for two-tower models, using in-batch negatives [13, 36], i.e., taking positive items of other users in the same mini-batch ......

xei/recommender-system-tutorial · GitHub
A step-by-step tutorial on developing a practical recommendation system (retrieval and ranking) using TensorFlow Recommenders and Keras.
