Can't deserialize model using __add__ or Add()
Clone of https://github.com/tensorflow/tensorflow/issues/57574
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
- TensorFlow installed from (source or binary): source
- TensorFlow version (use command below): v2.9.1-0-gd8ce9f9c301 2.9.1
- Python version: 3.9.13
- Bazel version (if compiling from source): 5.0.0
- GPU model and memory: nvidia A100-SXM4-40GB
- Exact command to reproduce:
python repro.py
Describe the problem.
When I declare a model as such:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        self.dense = Dense(512)

    def __call__(self, X):
        short = X
        X = self.dense(X)
        X = short + X
        return X


def main():
    model = tf.keras.Sequential([MyLayer()])
    model.build([None, 512])
    model.save("/tmp/my_model")
    tf.keras.models.load_model("/tmp/my_model")


if __name__ == "__main__":
    main()
Describe the current behavior.
I get an exception trying to load the model:
api_dispatcher.Dispatch=<bound method PyCapsule.Dispatch of <Dispatch(_add_dispatch): >, args=(<tf.Tensor 'Placeholder:0' shape=(None, 512) dtype=float32>,), kwargs={}
Traceback (most recent call last):
  File "/mnt/disks/data/src/repro.py", line 24, in <module>
    main()
  File "/mnt/disks/data/src/repro.py", line 21, in main
    tf.keras.models.load_model("/tmp/my_model")
  File "/mnt/disks/data/src/conda/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/mnt/disks/data/src/conda/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1078, in op_dispatch_handler
    result = api_dispatcher.Dispatch(args, kwargs)
TypeError: Missing required positional argument
Describe the expected behavior.
The model loads fine.
- Do you want to contribute a PR? (yes/no): no
Top GitHub Comments
Hi Froody!
It looks like you are overriding __call__ directly instead of the call method, which is what subclasses of Layer should override (see Layer docs).
I modified the gist from tlakrayal@ to rename your override from __call__ to call, and it now appears to be working as intended.
I hope this unblocks you for now, and I am going to look into options for how to better prevent this for future users.
Thanks!
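For reference, the rename described in this comment might look like the sketch below. It assumes the same layer as in the report; the only change is the method name, and the two lines at the bottom are an illustrative forward pass that was not part of the original repro.

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer


class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        self.dense = Dense(512)

    # Override call(), not __call__(). The inherited Layer.__call__ is left
    # intact and invokes call() as part of its own processing.
    def call(self, X):
        short = X
        X = self.dense(X)
        return short + X


# Illustrative forward pass (not in the original report): invoking the
# layer goes through Layer.__call__, which in turn runs call().
layer = MyLayer()
out = layer(tf.zeros((2, 512)))
```

The main() function from the report can stay as it is. According to the comment above, the renamed version saved and loaded in the maintainer's gist; that result is not re-verified here.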
I don’t think there is a clean way of preventing accidental override of __call__, because there are valid use cases for overriding it.
I’ve looked through the Layer docs again, and between the docstrings about call and __call__ and the examples, I think there is currently sufficient guidance toward overriding call for use cases like this one.
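To make the call versus __call__ distinction concrete, here is a toy sketch in plain Python. ToyLayer, GoodResidual and BadResidual are hypothetical names invented for this illustration, and the bookkeeping shown is a stand-in, not Keras's actual implementation. The point is only that a base class can do work in __call__ around a user-supplied call, and that overriding __call__ silently skips that work.

```python
class ToyLayer:
    """Hypothetical stand-in for a framework base layer (not Keras)."""

    def __init__(self):
        self.built = False
        self.call_count = 0

    def __call__(self, x):
        # Framework-side bookkeeping wrapped around the user's call().
        if not self.built:
            self.built = True
        self.call_count += 1
        return self.call(x)

    def call(self, x):
        raise NotImplementedError


class GoodResidual(ToyLayer):
    # Overrides call(): the base-class __call__ still runs.
    def call(self, x):
        return x + 2 * x


class BadResidual(ToyLayer):
    # Overrides __call__(): the base-class bookkeeping never runs.
    def __call__(self, x):
        return x + 2 * x


good = GoodResidual()
bad = BadResidual()
good_out = good(1)
bad_out = bad(1)
print(good_out, good.built, good.call_count)
print(bad_out, bad.built, bad.call_count)
```

Both objects compute the same value when invoked, which is why the mistake is easy to miss: the difference only shows up in state the base class was supposed to maintain.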