Unable to serialize/save TF2.0 RobertaSequenceClassification model to saved model format
🐛 Bug
I am getting an error while trying to serialize/save a TF 2.0 `TFRobertaForSequenceClassification` Keras model to the SavedModel format. I do not see this issue with the Bert or Albert model architectures. The test script below can be used to reproduce the issue.
Information
Model I am using (Bert, XLNet …): Roberta
Language I am using the model on (English, Chinese …): English
The problem arises when using:
- [ ] the official example scripts: (give details below)
- [x] my own modified scripts: (see the script below)
```python
import tensorflow as tf
from transformers import *

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = TFRobertaForSequenceClassification.from_pretrained('roberta-base')

########## Uncomment the following 2 lines for testing with BERT ############
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')

input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]
outputs = model(input_ids)
logits = outputs[0]

tf_saved_model_path = "/tmp/saved_model/"
tf.keras.models.save_model(model, tf_saved_model_path, overwrite=True, include_optimizer=False, save_format='tf')
```
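Until the fix lands, one possible workaround is to bypass Keras's SavedModel writer and persist the model with the library's own serialization API instead. This is only a sketch, and it does not produce a SavedModel directory, but it round-trips the weights and config:

```python
# Workaround sketch: save with the transformers API instead of tf.keras.models.save_model.
# This writes an H5 weights file plus config.json rather than a TF SavedModel, so it
# only helps if reloading via from_pretrained is acceptable for your use case.
model.save_pretrained("/tmp/roberta_seq_cls")
reloaded = TFRobertaForSequenceClassification.from_pretrained("/tmp/roberta_seq_cls")
```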
The task I am working on is:
- [ ] an official GLUE/SQuAD task: (give the name)
- [x] my own task or dataset: I need to export/serialize a TF Keras model to the TF SavedModel format
To reproduce
Steps to reproduce the behavior:
- Run the script pasted above to reproduce the issue with Roberta.
- Uncomment the two lines noted in the script to run with Bert instead (no error occurs with Bert).
Stack Trace for Roberta
```
TypeError                                 Traceback (most recent call last)
<ipython-input-5-87e63ee0b3ac> in <module>
     15
     16 tf_saved_model_path= "/tmp/saved_model/"
---> 17 tf.keras.models.save_model(model, tf_saved_model_path, overwrite=True, include_optimizer=False, save_format='tf')

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    136   else:
    137     saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 138                           signatures, options)
    139
    140

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
     76   # we use the default replica context here.
     77   with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
---> 78     save_lib.save(model, filepath, signatures, options)
     79
     80   if not include_optimizer:

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    949
    950   _, exported_graph, object_saver, asset_info = _build_meta_graph(
--> 951       obj, export_dir, signatures, options, meta_graph_def)
    952   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
    953

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, export_dir, signatures, options, meta_graph_def)
   1035
   1036   object_graph_proto = _serialize_object_graph(saveable_view,
-> 1037                                                asset_info.asset_index)
   1038   meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
   1039

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/saved_model/save.py in _serialize_object_graph(saveable_view, asset_file_def_index)
    695   for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    696     _write_object_proto(obj, obj_proto, asset_file_def_index,
--> 697                         saveable_view.function_name_map)
    698   return proto
    699

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/saved_model/save.py in _write_object_proto(obj, proto, asset_file_def_index, function_name_map)
    735       version=versions_pb2.VersionDef(
    736           producer=1, min_consumer=1, bad_consumers=[]),
--> 737       metadata=obj._tracking_metadata)
    738   # pylint:enable=protected-access
    739   proto.user_object.CopyFrom(registered_type_proto)

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in _tracking_metadata(self)
   2727   @property
   2728   def _tracking_metadata(self):
-> 2729     return self._trackable_saved_model_saver.tracking_metadata
   2730
   2731   def _list_extra_dependencies_for_serialization(self, serialization_cache):

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in tracking_metadata(self)
     52     # TODO(kathywu): check that serialized JSON can be loaded (e.g., if an
     53     # object is in the python property)
---> 54     return json_utils.Encoder().encode(self.python_properties)
     55
     56   def list_extra_dependencies_for_serialization(self, serialization_cache):

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/json_utils.py in encode(self, obj)
     42
     43   def encode(self, obj):
---> 44     return super(Encoder, self).encode(_encode_tuple(obj))
     45
     46

/usr/local/opt/pyenv/versions/3.6.7/lib/python3.6/json/encoder.py in encode(self, o)
    197         # exceptions aren't as detailed.  The list call should be roughly
    198         # equivalent to the PySequence_Fast that ''.join() would do.
--> 199         chunks = self.iterencode(o, _one_shot=True)
    200         if not isinstance(chunks, (list, tuple)):
    201             chunks = list(chunks)

/usr/local/opt/pyenv/versions/3.6.7/lib/python3.6/json/encoder.py in iterencode(self, o, _one_shot)
    255                 self.key_separator, self.item_separator, self.sort_keys,
    256                 self.skipkeys, _one_shot)
--> 257         return _iterencode(o, 0)
    258
    259     def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/keras/saving/saved_model/json_utils.py in default(self, obj)
     39       items = obj.as_list() if obj.rank is not None else None
     40       return {'class_name': 'TensorShape', 'items': items}
---> 41     return serialization.get_json_type(obj)
     42
     43   def encode(self, obj):

~/huggingface/transformers/env/lib/python3.6/site-packages/tensorflow/python/util/serialization.py in get_json_type(obj)
     74     return obj.wrapped
     75
---> 76   raise TypeError('Not JSON Serializable:', obj)

TypeError: ('Not JSON Serializable:', RobertaConfig {
  "_num_labels": 2,
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": 0,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": 2,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-05,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 514,
  "min_length": 0,
  "model_type": "roberta",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pad_token_id": 1,
  "prefix": null,
  "pruned_heads": {},
  "repetition_penalty": 1.0,
  "task_specific_params": null,
  "temperature": 1.0,
  "top_k": 50,
  "top_p": 1.0,
  "torchscript": false,
  "type_vocab_size": 1,
  "use_bfloat16": false,
  "vocab_size": 50265
})
```
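The bottom of the trace points at the root cause: Keras's SavedModel serializer JSON-encodes each layer's metadata (`python_properties`), and for this model that metadata evidently includes the `RobertaConfig` object (it is the value the encoder rejects), which the stock JSON encoder has no handler for. Below is a minimal sketch of the failure mode; the `json.dumps` call mirrors what the encoder attempts, while `to_json_string()` is the config class's own serialization helper:

```python
import json
from transformers import RobertaConfig

config = RobertaConfig.from_pretrained('roberta-base')

# Mirrors what the Keras SavedModel serializer attempts: the stock JSON
# encoder has no handler for a RobertaConfig instance, hence the TypeError.
try:
    json.dumps(config)
except TypeError as err:
    print(err)

# The same data serializes fine once converted explicitly via the config's API.
print(json.loads(config.to_json_string())["model_type"])  # roberta
```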
Expected behavior
There should be no error when saving/serializing the TF Keras model for Roberta. I do not see any error with Bert or Albert.
Environment info
- `transformers` version: 2.7.0
- Platform: Darwin-19.2.0-x86_64-i386-64bit
- Python version: 3.6.7
- PyTorch version (GPU?): 1.4.0 (False)
- Tensorflow version (GPU?): 2.2.0-rc1 (False)
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
I also see the same issue with TF 2.1.0.
Top GitHub Comments
> You have to install transformers from the master branch. The fix has not been released yet.

> Please open another issue with a code snippet so that we can reproduce your problem.
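For reference, installing `transformers` from the master branch (as the first comment suggests) is typically done with pip's VCS support; this assumes a standard pip/git setup:

```bash
pip install git+https://github.com/huggingface/transformers.git
```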