
model.save() does not save Keras model that includes DistilBERT layer

See original GitHub issue

🐛 Bug

Information

I am trying to build a Keras model in which I use DistilBERT as a non-trainable embedding layer. The model compiles and fits well, and even predict() works. But when I try to save it with model.save('model.h5'), it fails with the following error:

> ---------------------------------------------------------------------------
> NotImplementedError                       Traceback (most recent call last)
> <ipython-input-269-557c9cec7497> in <module>
> ----> 1 model.get_config()
> 
> /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py in get_config(self)
>     966     if not self._is_graph_network:
>     967       raise NotImplementedError
> --> 968     return copy.deepcopy(get_network_config(self))
>     969 
>     970   @classmethod
> 
> /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
>    2117           filtered_inbound_nodes.append(node_data)
>    2118 
> -> 2119     layer_config = serialize_layer_fn(layer)
>    2120     layer_config['name'] = layer.name
>    2121     layer_config['inbound_nodes'] = filtered_inbound_nodes
> 
> /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
>     273         return serialize_keras_class_and_config(
>     274             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
> --> 275       raise e
>     276     serialization_config = {}
>     277     for key, item in config.items():
> 
> /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
>     268     name = get_registered_name(instance.__class__)
>     269     try:
> --> 270       config = instance.get_config()
>     271     except NotImplementedError as e:
>     272       if _SKIP_FAILED_SERIALIZATION:
> 
> /usr/local/lib/python3.7/site-packages/tensorflow/python/keras/engine/network.py in get_config(self)
>     965   def get_config(self):
>     966     if not self._is_graph_network:
> --> 967       raise NotImplementedError
>     968     return copy.deepcopy(get_network_config(self))
>     969 
> 
> NotImplementedError: 

The language I am using the model in is English.

The problem arises when using my own modified scripts (details below):

import tensorflow as tf
from transformers import DistilBertConfig, TFDistilBertModel, DistilBertTokenizer

max_len = 8
distil_bert = 'distilbert-base-uncased'
config = DistilBertConfig(dropout=0.2, attention_dropout=0.2)
config.output_hidden_states = False
transformer_model = TFDistilBertModel.from_pretrained(distil_bert, config=config)

input_word_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
distill_output = transformer_model(input_word_ids)[0]

cls_out = tf.keras.layers.Lambda(lambda seq: seq[:, 0, :])(distill_output)
X = tf.keras.layers.BatchNormalization()(cls_out)
X = tf.keras.layers.Dense(256, activation='relu')(X)
X = tf.keras.layers.Dropout(0.2)(X)

X = tf.keras.layers.BatchNormalization()(X)
X = tf.keras.layers.Dense(128, activation='relu')(X)
X = tf.keras.layers.Dropout(0.2)(X)

X = tf.keras.layers.BatchNormalization()(X)
X = tf.keras.layers.Dense(64, activation='relu')(X)
X = tf.keras.layers.Dropout(0.2)(X)

X = tf.keras.layers.Dense(2)(X)
model = tf.keras.Model(inputs=input_word_ids, outputs=X)

# Freeze the first three layers (the Input layer, the DistilBERT layer, and the Lambda slice)
for layer in model.layers[:3]:
    layer.trainable = False
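
For context, here is a minimal sketch of the calls that work and the one that fails, run against the model built above. The compile settings and the dummy batch of token ids are illustrative assumptions, not part of the original report.

import numpy as np

# Any optimizer/loss combination reproduces the issue; these are placeholders.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# predict() works on a dummy batch of token ids of shape (batch_size, max_len).
dummy_ids = np.random.randint(0, 100, size=(2, max_len), dtype=np.int32)
print(model.predict(dummy_ids).shape)  # (2, 2)

# Saving to HDF5 goes through get_config(), which raises NotImplementedError
# because the wrapped TFDistilBertModel is not a graph network (see traceback above).
try:
    model.save('model.h5')
except NotImplementedError as err:
    print('model.save() failed:', repr(err))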

The task I am working on uses my own dataset.

To reproduce

Steps to reproduce the behavior:

  1. Run the above code
  2. You will get the error when saving the model with:
model.save('model.h5')

You can get the same error if you try:

model.get_config()

An interesting observation: if you save the model without specifying “.h5”, like

model.save('./model')

it saves the model in the TensorFlow SavedModel format and creates folders (assets (empty), variables, and some index files). But if you try to load the model, it produces different errors related to DistilBERT. This may be due to a naming inconsistency (input_ids vs. inputs, see below) inside the DistilBERT model.


new_model = tf.keras.models.load_model('./model')

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.7/site-packages/tensorflow/python/util/nest.py in assert_same_structure(nest1, nest2, check_types, expand_composites)
    377     _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
--> 378                                       expand_composites)
    379   except (ValueError, TypeError) as e:

ValueError: The two structures don't have the same nested structure.

First structure: type=dict str={'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name='input_ids')}

Second structure: type=TensorSpec str=TensorSpec(shape=(None, 8), dtype=tf.int32, name='inputs')

More specifically: Substructure "type=dict str={'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name='input_ids')}" is a sequence, while substructure "type=TensorSpec str=TensorSpec(shape=(None, 8), dtype=tf.int32, name='inputs')" is not

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-229-b46ed71fd9ad> in <module>
----> 1 new_model = tf.keras.models.load_model(keras_model_path)

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile)
    188     if isinstance(filepath, six.string_types):
    189       loader_impl.parse_saved_model(filepath)
--> 190       return saved_model_load.load(filepath, compile)
    191 
    192   raise IOError(

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load(path, compile)
    114   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
    115   # TODO(kathywu): Add code to load from objects that contain all endpoints
--> 116   model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
    117 
    118   # pylint: disable=protected-access

/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in load_internal(export_dir, tags, loader_cls)
    602       loader = loader_cls(object_graph_proto,
    603                           saved_model_proto,
--> 604                           export_dir)
    605       root = loader.get(0)
    606       if isinstance(loader, Loader):

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in __init__(self, *args, **kwargs)
    186     self._models_to_reconstruct = []
    187 
--> 188     super(KerasObjectLoader, self).__init__(*args, **kwargs)
    189 
    190     # Now that the node object has been fully loaded, and the checkpoint has

/usr/local/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py in __init__(self, object_graph_proto, saved_model_proto, export_dir)
    121       self._concrete_functions[name] = _WrapperFunction(concrete_function)
    122 
--> 123     self._load_all()
    124     self._restore_checkpoint()
    125 

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_all(self)
    213 
    214     # Finish setting up layers and models. See function docstring for more info.
--> 215     self._finalize_objects()
    216 
    217   @property

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _finalize_objects(self)
    504         layers_revived_from_saved_model.append(node)
    505 
--> 506     _finalize_saved_model_layers(layers_revived_from_saved_model)
    507     _finalize_config_layers(layers_revived_from_config)
    508 

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _finalize_saved_model_layers(layers)
    675       call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
    676       if call_fn.input_signature is None:
--> 677         inputs = infer_inputs_from_restored_call_function(call_fn)
    678       else:
    679         inputs = call_fn.input_signature[0]

/usr/local/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in infer_inputs_from_restored_call_function(fn)
    919   for concrete in fn.concrete_functions[1:]:
    920     spec2 = concrete.structured_input_signature[0][0]
--> 921     spec = nest.map_structure(common_spec, spec, spec2)
    922   return spec
    923 

/usr/local/lib/python3.7/site-packages/tensorflow/python/util/nest.py in map_structure(func, *structure, **kwargs)
    609   for other in structure[1:]:
    610     assert_same_structure(structure[0], other, check_types=check_types,
--> 611                           expand_composites=expand_composites)
    612 
    613   flat_structure = [flatten(s, expand_composites) for s in structure]

/usr/local/lib/python3.7/site-packages/tensorflow/python/util/nest.py in assert_same_structure(nest1, nest2, check_types, expand_composites)
    383                   "Entire first structure:\n%s\n"
    384                   "Entire second structure:\n%s"
--> 385                   % (str(e), str1, str2))
    386 
    387 

ValueError: The two structures don't have the same nested structure.

First structure: type=dict str={'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name='input_ids')}

Second structure: type=TensorSpec str=TensorSpec(shape=(None, 8), dtype=tf.int32, name='inputs')

More specifically: Substructure "type=dict str={'input_ids': TensorSpec(shape=(None, 5), dtype=tf.int32, name='input_ids')}" is a sequence, while substructure "type=TensorSpec str=TensorSpec(shape=(None, 8), dtype=tf.int32, name='inputs')" is not
Entire first structure:
{'input_ids': .}
Entire second structure:
.

Expected behavior

I expect the model to save and load normally.

Environment info

  • transformers version: 2.9.1
  • Platform:
  • Python version: 3.7.6
  • TensorFlow version (CPU): 2.2.0
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 13
  • Comments: 21 (11 by maintainers)

Top GitHub Comments

16 reactions
pdegner commented, Jul 31, 2020

I had this exact error. I got around it by saving the weights and keeping the code that creates the model. After training your model, run model.save_weights('path/savefile'). Note there is no .h5 on it.

When you want to reuse the model later, run your code up to model.compile(), then call model.load_weights('path/savefile').
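
Below is a minimal sketch of that weights-only workaround. The build_model() helper is hypothetical and condenses the architecture from the issue; the smaller dense head and the compile settings are simplified placeholders.

import os
import tensorflow as tf
from transformers import DistilBertConfig, TFDistilBertModel

max_len = 8

def build_model():
    # Rebuild the same DistilBERT-plus-dense-head architecture as in the issue
    # (condensed here) and compile it, so the code can be re-run before loading weights.
    config = DistilBertConfig(dropout=0.2, attention_dropout=0.2)
    config.output_hidden_states = False
    transformer = TFDistilBertModel.from_pretrained('distilbert-base-uncased', config=config)

    input_word_ids = tf.keras.layers.Input(shape=(max_len,), dtype=tf.int32, name='input_word_ids')
    sequence_output = transformer(input_word_ids)[0]
    cls_out = tf.keras.layers.Lambda(lambda seq: seq[:, 0, :])(sequence_output)
    x = tf.keras.layers.Dense(64, activation='relu')(cls_out)
    outputs = tf.keras.layers.Dense(2)(x)

    model = tf.keras.Model(inputs=input_word_ids, outputs=outputs)
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    return model

# After training, save only the weights. Note: no '.h5' suffix, so the
# TensorFlow checkpoint format is used.
os.makedirs('path', exist_ok=True)
model = build_model()
# model.fit(train_ids, train_labels, ...)  # train as usual
model.save_weights('path/savefile')

# Later, re-run the same model-building code up to compile(), then restore the weights.
new_model = build_model()
new_model.load_weights('path/savefile')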

13 reactions
sajib-kumar commented, Jun 17, 2020

Same issue

Read more comments on GitHub >

Top Results From Across the Web

Save and load Keras models | TensorFlow Core
It is the default when you use model.save() . ... graph of custom objects such as custom layers is not included in the...
Read more >
How to save bert or distilbert model? - Hugging Face Forums
Hi ALL! I'm having issues loading my trained distilbert model, cannot figure out a way to resolve the issue, when i try to...
Read more >
Isues with saving and loading tensorflow model which uses ...
I am using the latest Huggingface transformers tensorflow keras version. The idea is to extract features using distilbert and then run the ...
Read more >
Model saving & serialization APIs - Keras
The traced functions allow the SavedModel format to save and load custom layers without the original class definition. You can choose to not...
Read more >
Compiling and Deploying Pretrained HuggingFace Pipelines ...
This save method prefers to work on a flat input/output lists and does not work on dictionary input/output - which is what the...
Read more >
