[Question] After saving & loading the TF model / ScaNN / BruteForce objects with dict input for the user tower, the loaded model won't work properly
### About Save & Load BruteForce/ScaNN and Model object
I am working through this tutorial with an online shopping dataset, and followed the part where the user tower is similar to this:
```python
class UserModel(tf.keras.Model):
    ...

    def call(self, inputs):
        # Take the input dictionary, pass it through each input layer,
        # and concatenate the result.
        return tf.concat([
            self.user_embedding(inputs["user_id"]),
            self.timestamp_embedding(inputs["timestamp"]),
            tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),
        ], axis=1)
```
where `inputs` is a dict. I can get the embedding with `UserModel()(input_dict)` just fine.
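For example, something like this works (a minimal sketch with made-up feature values, assuming the elided constructor takes no required arguments):

```python
import numpy as np

# Hypothetical example record; the keys match the dict used in call() above.
input_dict = {
    "user_id": np.array([b"42"]),
    "timestamp": np.array([879024327]),
}

# Returns the concatenated user embedding as expected.
user_embedding = UserModel()(input_dict)
```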
The issue arises when I work with the example in this link: https://www.tensorflow.org/recommenders/examples/efficient_serving, where we want to save the ScaNN/BruteForce object.
I am able to get ScaNN working and can call it:
```python
scann = tfrs.layers.factorized_top_k.ScaNN(
    model.user_model, num_reordering_candidates=100)

scann.index_from_dataset(
    sku_map.batch(2048).map(lambda x: (x["SKU_KEY"], model.sku_model(x))))
# (sku_map.batch(2048).map(lambda x: x["SKU_KEY"]),
#  sku_map.batch(2048).map(model.sku_model))

scann({'CONTEXT_ID': np.array([[b'263', b'34', b'555', b'44', b'3300']]),
       'USER_ID': np.array([b'sssksksksksksk'])})
```
and it returns meaningful results. However, if I follow the "Deploying the approximate model" section to save it and load it back, I get an error.
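The save and load steps are roughly these (a sketch following the tutorial; the path is just a placeholder):

```python
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), "scann_index")

# The ScaNN ops live in the "Scann" namespace, which must be
# whitelisted when saving.
tf.saved_model.save(
    scann,
    path,
    options=tf.saved_model.SaveOptions(namespace_whitelist=["Scann"]))

loaded = tf.saved_model.load(path)
```

Calling `loaded(...)` with the same query dict as above then raises: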
```
ValueError: Could not find matching function to call loaded from the SavedModel. Got:
  Positional arguments (3 total):
    * {'CONTEXT_ID': <tf.Tensor 'queries:0' shape=(1, 5) dtype=string>, 'USER_ID': <tf.Tensor 'queries_1:0' shape=(1,) dtype=string>}
    * None
    * False
  Keyword arguments: {}

Expected these arguments to match one of the following 4 option(s):

Option 1:
  Positional arguments (3 total):
    * {'PRICE': TensorSpec(shape=(None,), dtype=tf.float32, name='PRICE'), 'USER_ID': TensorSpec(shape=(None,), dtype=tf.string, name='USER_ID'), 'TRANS_COUNT': TensorSpec(shape=(None,), dtype=tf.int64, name='TRANS_COUNT'), 'SKU_KEY': TensorSpec(shape=(None, 1), dtype=tf.string, name='SKU_KEY'), 'SKU_DESC': TensorSpec(shape=(None,), dtype=tf.string, name='SKU_DESC'), 'CONTEXT_ID': TensorSpec(shape=(None, 5), dtype=tf.string, name='CONTEXT_ID')}
    * None
    * False
  Keyword arguments: {}
```
It seems to me that (1) the expected shapes are all (None,), and (2) the loaded object can no longer match the input dict… The same thing happens if I save the whole model and load it back:
```python
model.save('my_model')
my_tf_saved_model = tf.keras.models.load_model('./my_model')
my_tf_saved_model(row)
```
It throws similar errors, even though `model(row)` (where `row` is a dict) works fine…
### Can't do model.evaluate after replacing factorized_metrics with BruteForce
Another strange finding: with the above setup, if I define the query model in the BruteForce layer (`brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)`), reset `factorized_metrics` to use it, and then run `model.evaluate` (for faster evaluation), I get an error.
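The setup is roughly this (a sketch; `test_ds` stands in for my evaluation dataset):

```python
brute_force = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
brute_force.index_from_dataset(
    sku_map.batch(2048).map(lambda x: (x["SKU_KEY"], model.sku_model(x))))

# Swap the slow streaming metrics for the indexed layer.
model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
    candidates=brute_force)
model.compile()

model.evaluate(test_ds, return_dict=True)
```

which fails with: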
```
TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got 'USER_ID'
```
It seems it does not like how I make the user tower take a dict as input and return `self.user_model(inputs['USER_ID'])`. However, it does work if I do not pass the user model when initializing BruteForce.
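For reference, the variant that works for me is roughly this (a sketch; without a query model, the metric passes precomputed query embeddings to the layer rather than raw features):

```python
# No query model here: the layer is called with query embeddings directly.
brute_force = tfrs.layers.factorized_top_k.BruteForce()
brute_force.index_from_dataset(
    sku_map.batch(2048).map(lambda x: (x["SKU_KEY"], model.sku_model(x))))

model.task.factorized_metrics = tfrs.metrics.FactorizedTopK(
    candidates=brute_force)
model.compile()

model.evaluate(test_ds, return_dict=True)  # runs without the TypeError
```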
Any insights would be appreciated!
@xiaoyaoyang When serialising a model, TensorFlow creates a strict function call signature based on tracing the model. Before serialising, you passed in a dict containing all of your features, so TensorFlow creates a call signature that expects all of those inputs, even if your model doesn't use them. So when you call it with just USER_ID, it will fail.

It's best to ensure you only pass the required features into your model during training and evaluation. Alternatively, you should be able to resolve this by calling the model once with an example record containing only the required features before serialising. This will result in another call signature that matches the input you expect to pass when serving.
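For example, something like this sketch (reusing the example query from the post; the exact features depend on what the query model actually consumes):

```python
# Trace the model once with only the serving-time features before saving,
# so the SavedModel also records a call signature matching these inputs.
_ = model({
    "CONTEXT_ID": np.array([[b"263", b"34", b"555", b"44", b"3300"]]),
    "USER_ID": np.array([b"sssksksksksksk"]),
})
model.save("my_model")
```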
By "call the model" I mean getting a prediction. After you index the brute_force layer, you need to run a prediction once, for example (a sketch reusing the same example query):
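```python
# Get a prediction from the indexed layer once before serialising
# (example query reused from the original post).
scores, ids = brute_force({
    "USER_ID": np.array([b"sssksksksksksk"]),
    "CONTEXT_ID": np.array([[b"263", b"34", b"555", b"44", b"3300"]]),
})
```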