How can "Estimator.predict" do real-time prediction?
I want to use BERT for a sentiment classification task. I fine-tuned BERT on a dataset and got a working model, but then I found that predicting a single sample is very slow. Someone said the reason is that Estimator.predict reloads the graph on every call, which takes almost five seconds!
Some people say using tf.data.Dataset.from_generator() is a solution, but I still can't get it to work after many attempts.
Here is part of my code:
import uuid

import numpy as np
import tensorflow as tf

def input_gen(self):
    while True:
        text = "这是一个测试"  # "this is a test"; swap in input() for interactive use
        examples = [InputExample(guid=uuid.uuid4(), text_a=text, text_b=None, label="0")]
        features = BertSentiment.examples_to_features(examples, self.label_list, self.max_seq_length, self.tokenizer)
        all_input_ids = []
        all_input_mask = []
        all_segment_ids = []
        all_label_ids = []
        for feature in features:
            all_input_ids.append(feature.input_ids)
            all_input_mask.append(feature.input_mask)
            all_segment_ids.append(feature.segment_ids)
            all_label_ids.append(feature.label_id)
        num_examples = len(features)
        # from_generator expects numpy-compatible values, not tf.constant tensors
        input_ids = np.asarray(all_input_ids, dtype=np.int32).reshape(num_examples, self.max_seq_length)
        input_mask = np.asarray(all_input_mask, dtype=np.int32).reshape(num_examples, self.max_seq_length)
        segment_ids = np.asarray(all_segment_ids, dtype=np.int32).reshape(num_examples, self.max_seq_length)
        label_ids = np.asarray(all_label_ids, dtype=np.int32)
        print("here1:{}".format(input_ids.shape))
        yield input_ids, input_mask, segment_ids, label_ids

def input_fn(self, params):
    """The actual input function."""
    d = tf.data.Dataset.from_generator(self.input_gen, output_types=(tf.int32, tf.int32, tf.int32, tf.int32))
    # Each generator element is already a [num_examples, max_seq_length] batch,
    # so a further d.batch() would add an extra dimension and break the model.
    iterator = d.make_one_shot_iterator()
    input_ids, input_mask, segment_ids, label_ids = iterator.get_next()
    return {'input_ids': input_ids, 'input_mask': input_mask, 'segment_ids': segment_ids, 'label_ids': label_ids}

def predict(self):
    result_iter = self.estimator.predict(input_fn=self.input_fn)
    # The generator loops forever, so list(result_iter) would never return;
    # pull predictions one at a time instead.
    print(next(result_iter))
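The trick that makes from_generator work for real-time prediction is to call estimator.predict only once and keep the resulting iterator alive, feeding each new request into the generator through a queue; calling predict() again per request rebuilds the graph every time. A minimal sketch of that pattern, assuming single-example batches (the RealtimePredictor class and the queue plumbing are illustrative, not from the original code):

import queue

import tensorflow as tf

class RealtimePredictor:
    """Builds the graph once by keeping a single Estimator.predict() call alive."""

    def __init__(self, estimator, max_seq_length):
        self._inputs = queue.Queue()
        self._seq = max_seq_length
        # predict() is lazy: the graph is built on the first next() below.
        self._results = estimator.predict(input_fn=self._input_fn)

    def _gen(self):
        while True:
            yield self._inputs.get()  # blocks until predict() supplies features

    def _input_fn(self, params=None):
        types = {'input_ids': tf.int32, 'input_mask': tf.int32, 'segment_ids': tf.int32}
        shapes = {k: [None, self._seq] for k in types}
        return tf.data.Dataset.from_generator(self._gen, output_types=types, output_shapes=shapes)

    def predict(self, features):
        self._inputs.put(features)  # dict of int32 arrays shaped [1, max_seq_length]
        return next(self._results)  # reuses the already-loaded graph

The first call pays the one-time graph-building cost; every later call only runs the session.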
Here is a simple alternative: instead of using tf.estimator.predict, you can export your model and use predictor from tensorflow.contrib. This way you avoid loading the graph every time. It roughly works like this:
First export your model after training:
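A sketch of what the export can look like, assuming the model_fn consumes the three BERT feature names from the question (the export path and the MAX_SEQ_LENGTH constant are illustrative):

import tensorflow as tf

MAX_SEQ_LENGTH = 128  # must match the sequence length used for fine-tuning

def serving_input_receiver_fn():
    # Placeholder names must match the features the model_fn reads.
    features = {
        'input_ids': tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_ids'),
        'input_mask': tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='input_mask'),
        'segment_ids': tf.placeholder(tf.int32, [None, MAX_SEQ_LENGTH], name='segment_ids'),
    }
    return tf.estimator.export.ServingInputReceiver(features, features)

# Writes a timestamped SavedModel directory under ./exported_model
estimator.export_savedmodel('./exported_model', serving_input_receiver_fn)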
Then when you want to use it for inference reload your model using predictor:
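A matching sketch for the inference side (the glob over the timestamped export directory is illustrative):

import glob

from tensorflow.contrib import predictor

# export_savedmodel creates a timestamped subdirectory; load the latest one.
export_dir = sorted(glob.glob('./exported_model/*'))[-1]
predict_fn = predictor.from_saved_model(export_dir)

# Inputs are numpy arrays shaped [batch_size, MAX_SEQ_LENGTH], dtype int32.
predictions = predict_fn({
    'input_ids': input_ids,
    'input_mask': input_mask,
    'segment_ids': segment_ids,
})

predict_fn keeps the session open, so repeated calls skip graph loading entirely.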
This should speed your inference up significantly.
Well, after a day's hard work, I found the answer: https://hanxiao.github.io/2019/01/02/Serving-Google-BERT-in-Production-using-Tensorflow-and-ZeroMQ/
Thank you very much, Dr. Han Xiao.
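For readers who follow that link: it describes bert-as-service, which keeps the BERT graph loaded in a standalone server process so each request skips graph construction. A minimal client-side sketch (it assumes a server started separately with bert-serving-start; note the service returns sentence embeddings, so a sentiment classifier still has to run on top of them):

from bert_serving.client import BertClient

# Assumes a running server, e.g.:
#   bert-serving-start -model_dir /path/to/chinese_L-12_H-768_A-12 -num_worker=1
bc = BertClient()
vec = bc.encode(['这是一个测试'])  # "this is a test" -> embedding of shape [1, 768]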