
[Question] After saving & loading the TF model / ScaNN / BruteForce objects with dict input for the user tower, the loaded model won't work properly

See original GitHub issue

About saving & loading the BruteForce/ScaNN and model objects

I am working through this tutorial with an online shopping dataset, and followed it to the point where the user tower looks similar to this:

class UserModel(tf.keras.Model):
    ...  # embedding layers are defined in __init__, elided here

    def call(self, inputs):
        # Take the input dictionary, pass it through each input layer,
        # and concatenate the result.
        return tf.concat([
            self.user_embedding(inputs["user_id"]),
            self.timestamp_embedding(inputs["timestamp"]),
            tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),
        ], axis=1)

Where inputs is a dict.
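For example, calling it directly with a feature dict returns the concatenated embedding. A minimal sketch, assuming the layers above; the feature values are illustrative:

import numpy as np

input_dict = {
    "user_id": np.array([b"42"]),
    "timestamp": np.array([879024327]),
}
embedding = UserModel()(input_dict)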

Calling UserModel()(input_dict) this way returns the embedding just fine. The issue arises when I work through the example at https://www.tensorflow.org/recommenders/examples/efficient_serving, where we want to save the ScaNN/BruteForce object.

I am able to get the ScaNN layer working and can call it:

scann = tfrs.layers.factorized_top_k.ScaNN(
    model.user_model, num_reordering_candidates=100)
scann.index_from_dataset(
    sku_map.batch(2048).map(lambda x: (x["SKU_KEY"], model.sku_model(x))))
# Alternative indexing attempt:
# (sku_map.batch(2048).map(lambda x: x["SKU_KEY"]),
#  sku_map.batch(2048).map(model.sku_model))
scann({
    'CONTEXT_ID': np.array([[b'263', b'34', b'555', b'44', b'3300']]),
    'USER_ID': np.array([b'sssksksksksksk']),
})

and it returns meaningful results. However, when I follow the "Deploying the approximate model" section to save the index and load it back, the call on the loaded object fails.
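For reference, the save/load pattern from that section looks roughly like this (the temporary directory is illustrative):

import os
import tempfile
import tensorflow as tf

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model")
    # ScaNN uses custom ops, so its namespace must be whitelisted on save.
    tf.saved_model.save(
        scann, path,
        options=tf.saved_model.SaveOptions(namespace_whitelist=["Scann"]))
    loaded = tf.saved_model.load(path)

Calling the loaded object with the same query dict then raises: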

ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):
    * {'CONTEXT_ID': <tf.Tensor 'queries:0' shape=(1, 5) dtype=string>, 'USER_ID': <tf.Tensor 'queries_1:0' shape=(1,) dtype=string>}
    * None
    * False
  Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (3 total):
    * {'PRICE': TensorSpec(shape=(None,), dtype=tf.float32, name='PRICE'), 'USER_ID': TensorSpec(shape=(None,), dtype=tf.string, name='USER_ID'), 'TRANS_COUNT': TensorSpec(shape=(None,), dtype=tf.int64, name='TRANS_COUNT'), 'SKU_KEY': TensorSpec(shape=(None, 1), dtype=tf.string, name='SKU_KEY'), 'SKU_DESC': TensorSpec(shape=(None,), dtype=tf.string, name='SKU_DESC'), 'CONTEXT_ID': TensorSpec(shape=(None, 5), dtype=tf.string, name='CONTEXT_ID')}
    * None
    * False
  Keyword arguments: {}

It seems to me that (1) the expected shapes all have a None batch dimension, and (2) the loaded model can no longer match my input dict against the traced signature. The same thing happens if I save the whole model and load it back:

model.save('my_model')
my_tf_saved_model = tf.keras.models.load_model('./my_model')
my_tf_saved_model(row)

The loaded model throws similar errors, even though model(row) (where row is a dict) works fine on the original model.
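A sketch of a possible workaround, my own addition rather than anything from the tutorial: trace a concrete function on only the required query features and export it as an explicit serving signature. The feature specs below are taken from the error message above.

import tensorflow as tf

# Hypothetical sketch: trace a serving function on only the two query
# features, then save that signature alongside the model.
serving_fn = tf.function(lambda features: model(features)).get_concrete_function({
    "USER_ID": tf.TensorSpec(shape=(None,), dtype=tf.string, name="USER_ID"),
    "CONTEXT_ID": tf.TensorSpec(shape=(None, 5), dtype=tf.string, name="CONTEXT_ID"),
})
model.save("my_model", signatures={"serving_default": serving_fn})

The pinned signature is then available via tf.saved_model.load('./my_model').signatures['serving_default'].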

Can’t do model.evaluate after replacing factorized_metrics with BruteForce

Another strange finding: with the above setup, if I pass the query model when constructing the BruteForce layer (brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)), then reset factorized_metrics to use it and run model.evaluate (for faster evaluation), I get an error:

    TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'USER_ID'

It seems the metric does not like that I let the user tower take a dict as input and return self.user_model(inputs['USER_ID']). However, it works if I do not pass the user model when initializing the BruteForce layer, as sketched below.
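For reference, the variant that evaluates without error looks roughly like this. A sketch assuming the retrieval task is stored as model.task, loosely following the evaluation section of the efficient-serving tutorial; the dataset names are from my setup above:

# Build the index without a query model; the metric feeds it query
# embeddings directly, so a dict-input query tower would break here.
brute_force = tfrs.layers.factorized_top_k.BruteForce()
brute_force.index_from_dataset(
    sku_map.batch(2048).map(lambda x: (x["SKU_KEY"], model.sku_model(x))))

model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
    candidates=brute_force)
model.compile()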

Any insights would be appreciated!

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 16

Top GitHub Comments

3 reactions
patrickorlando commented, Dec 9, 2021

@xiaoyaoyang When serialising a model, TensorFlow creates a strict function call signature based on tracing the model. Before serialising, you passed in a dict containing:

{'PRICE': TensorSpec(shape=(None,), dtype=tf.float32, name='PRICE'),
 'USER_ID': TensorSpec(shape=(None,), dtype=tf.string, name='USER_ID'),
 'TRANS_COUNT': TensorSpec(shape=(None,), dtype=tf.int64, name='TRANS_COUNT'),
 'SKU_KEY': TensorSpec(shape=(None, 1), dtype=tf.string, name='SKU_KEY'),
 'SKU_DESC': TensorSpec(shape=(None,), dtype=tf.string, name='SKU_DESC'),
 'CONTEXT_ID': TensorSpec(shape=(None, 5), dtype=tf.string, name='CONTEXT_ID')}

When you serialise your model, TensorFlow will create a call signature that expects all of those inputs, even if your model doesn't use them. So when you call it with just USER_ID, it will fail.

It’s best to ensure you only pass the required features into your model during training and evaluation. Alternatively, you should be able to resolve this by calling the model once, with an example record containing only the required features, before serialising. This will then result in another call signature that matches the input you expect to pass when serving.
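A minimal sketch of that second option (feature names and shapes follow the error message above; the values are illustrative):

import numpy as np

# Trace an extra call signature with only the required features,
# then save as usual.
model({
    "USER_ID": np.array([b"some_user"]),
    "CONTEXT_ID": np.array([[b"1", b"2", b"3", b"4", b"5"]]),
})
model.save("my_model")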

1 reaction
patrickorlando commented, Jan 21, 2022

By "call the model" I mean getting a prediction. After you index the brute_force layer, you need to run:

scores, identifiers = index(example_record)  # index: the indexed brute_force layer