
How to use model.save() in tf2 when using TFBertModel

See original GitHub issue

tensorflow==2.3.1

transformers==4.2.1

My model is defined as:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras.layers import *
from transformers import TFAutoModel

input_ids = Input(shape=(3000,), name='INPUT_input_ids', dtype=tf.int32)
input_mask = Input(shape=(3000,), name='INPUT_input_mask', dtype=tf.int32)
segment_ids = Input(shape=(3000,), name='INPUT_segment_ids', dtype=tf.int32)
passage_mask = Input(shape=(10,), name='INPUT_passage_mask', dtype=tf.int32)
# Split each 3000-token sequence into 10 passages of 300 tokens
input_ids_reshape = K.reshape(input_ids, (-1, 300))
input_mask_reshape = K.reshape(input_mask, (-1, 300))
segment_ids_reshape = K.reshape(segment_ids, (-1, 300))
transformer = TFAutoModel.from_pretrained('hfl/chinese-roberta-wwm-ext', from_pt=False)
transformer_output = transformer([input_ids_reshape, input_mask_reshape, segment_ids_reshape])[0]
# ... (elided) downstream layers that produce start_prob and end_prob ...
model = Model(
    inputs  = [input_ids, input_mask, segment_ids, passage_mask], 
    outputs = [start_prob, end_prob]
)

I tried to save the model this way:

model.save(path)

but I got this error:

/lib/python3.6/site-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
    364                     output[tensor_name] = input
    365                 else:
--> 366                     output[parameter_names[i]] = input
    367             elif isinstance(input, allowed_types) or input is None:
    368                 output[parameter_names[i]] = input

IndexError: list index out of range

model.predict() and model.save_weights() are working.

How can I use model.save() with huggingface-transformers? Or how should I write a model that uses huggingface-transformers? I just want to use a transformer as a Keras layer in my model.
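
Since model.save_weights() does work here, a minimal fallback sketch (my own assumption, not something confirmed in the thread; the file name is hypothetical) is to persist only the weights and rebuild the architecture in code before loading them back:

# Save only the weights, which works per the issue
model.save_weights('model_weights.h5')

# ... later, in a fresh process: re-run the model-building code above,
# then restore the trained weights into the rebuilt graph ...
model.load_weights('model_weights.h5')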

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

3 reactions
ffy2017 commented, Oct 29, 2021

I have the same problem. How do I fix it?

0 reactions
Crazy-LittleBoy commented, Oct 1, 2021

Hi @leisurehippo, not all of our models play nicely with model.save(). If you want a SavedModel output, please try model.save_pretrained() with saved_model=True. You can see more information about that method here. If you still run into the same problem when using save_pretrained, let me know and I'll try to reproduce it.
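
A minimal sketch of that suggestion (the export directory is hypothetical): export the transformer itself with save_pretrained(..., saved_model=True), then reload it with from_pretrained() when rebuilding the surrounding Keras graph:

# Export only the transformer component, as suggested above
transformer.save_pretrained('./roberta_export', saved_model=True)

# Later, reload it and rebuild the rest of the Keras model around it
transformer = TFAutoModel.from_pretrained('./roberta_export')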

I tried saving with the SavedModelBuilder class this way:

import tensorflow as tf

# Build a TF1-style prediction signature from the Keras model's inputs/outputs
signature = tf.compat.v1.saved_model.predict_signature_def(
    inputs={t.name: t for t in model.inputs},
    outputs={t.name: t for t in model.outputs}
)
builder = tf.compat.v1.saved_model.Builder(export_path)
builder.add_meta_graph_and_variables(
    sess=tf.compat.v1.keras.backend.get_session(),
    tags=[tf.compat.v1.saved_model.tag_constants.SERVING],
    signature_def_map={'predict': signature},
)
builder.save()

But the session doesn't seem to be the one from my model, and the resulting .pb file is very small.
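
For what it's worth, a TF2-native alternative to the TF1-style SavedModelBuilder above is tf.saved_model.save(); this is only a sketch, and the thread does not confirm whether it avoids the original IndexError:

import tensorflow as tf

# Export the whole Keras model as a SavedModel without a v1 session;
# export_path is the same hypothetical directory used above
tf.saved_model.save(model, export_path)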

Hi, did you fix this problem? And how? Thanks.


