load_model() fails in Flask request context only
```python
# predict.py
import os
import random
import re
import pickle
import utils
import shutil
import requests
import keras
from keras.models import load_model
from keras import backend as K


def load_classification_model(company_id):
    company_dir = os.path.realpath(os.path.join('./models', 'company_' + str(company_id)))
    # os.listdir() returns entries in arbitrary order; sort so [-1] is deterministic
    # (assumes the version directory names sort chronologically)
    model_dir = os.path.join(company_dir, sorted(os.listdir(company_dir))[-1])
    model_path = os.path.join(model_dir, 'model.h5')
    labels_path = os.path.join(model_dir, 'labels.pickle')
    print('Loading model ' + model_path + ' ...')
    model = load_model(model_path)
    # Backend function mapping (input, learning phase) to the last layer's output
    graph = K.function([model.layers[0].input, K.learning_phase()],
                       [model.layers[-1].output])
    with open(labels_path, 'rb') as f:
        class_names = pickle.load(f)
    return graph, class_names


def predict_image(company_id, url=None, part=None, inspection=None):
    model_graph, class_names = load_classification_model(company_id)
    # ...load image, preprocess, predict...
```
`import predict` followed by `predict.predict_image(...)` works perfectly in the REPL, but fails when called inside a Flask request context. The error traceback:
```
# curl ml:5000/predict?company=1&image_url=<image_url>
[top of traceback omitted for brevity]
  File "/code/app.py", line 22, in predict_route
    result = predict_image(company_id, url=image_url)
  File "/code/predict.py", line 34, in predict_image
    model_graph, class_names = load_classification_model(company_id)
  File "/code/predict.py", line 21, in load_classification_model
    model = load_model(model_path)
  File "/usr/local/lib/python2.7/site-packages/keras/models.py", line 242, in load_model
    topology.load_weights_from_hdf5_group(f['model_weights'], model.layers)
  File "/usr/local/lib/python2.7/site-packages/keras/engine/topology.py", line 3095, in load_weights_from_hdf5_group
    K.batch_set_value(weight_value_tuples)
  File "/usr/local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2193, in batch_set_value
    get_session().run(assign_ops, feed_dict=feed_dict)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1071, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(2048, 64), dtype=float32) is not an element of this graph.
```
Seems like a TensorFlow session collision, maybe?
Possibly relevant: inside `utils` there is another `load_model()` call for a ‘pre-processing’ model (a frozen pre-trained model). I need this base model because I am training many different (small) top classifier models to stack on top of the (large) pre-trained model. All of them use the same large base model, each with a different small classifier on top.
What I don’t understand is why this fails only inside a Flask request context, and not when calling `predict.predict_image(...)` from the REPL.
A related page with the same title (“Keras load_model() fails in Flask request context only”) suggests: “It seems there is a bug in Keras when using tensorflow graph cross threads. To fix it: # Right after loading or constructing...”
I had the same issue. The following resolved it for me:
```python
from keras import backend as K
K.clear_session()
```
You shouldn’t load models in request handlers (with Flask, concurrency is best handled outside of them). 😃
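One way to follow that advice is to load every company's model once at startup, before the server begins handling requests, and have the handlers only look models up. This is a minimal sketch, assuming the set of company IDs is known at startup; `ModelRegistry`, `fake_loader` and `handle_request` are hypothetical names, and `fake_loader` stands in for `load_classification_model()`. It records which thread did the loading to show that all loads happen on the startup thread, not on request threads. With TensorFlow 1.x Keras this alone may not be enough: the request threads would presumably still need to use the graph the models were loaded into.

```python
import threading


class ModelRegistry(object):
    """Holds every company's model, loaded once before the server starts."""

    def __init__(self, loader):
        # loader: a callable like load_classification_model(company_id) (assumed)
        self._loader = loader
        self._models = {}

    def preload(self, company_ids):
        # Run once at startup, on the thread that will own the models.
        for company_id in company_ids:
            self._models[company_id] = self._loader(company_id)

    def get(self, company_id):
        # Request handlers only read; they never trigger a load.
        return self._models[company_id]


load_log = []


def fake_loader(company_id):
    # Hypothetical stand-in for load_classification_model(): records the
    # loading thread and returns a dummy (graph, class_names) pair.
    load_log.append((company_id, threading.current_thread().name))
    return ("model-graph-%d" % company_id, ["class_a", "class_b"])


registry = ModelRegistry(fake_loader)
registry.preload([1, 2])  # e.g. at module level in app.py, before app.run()

results = []


def handle_request(company_id):
    # Simulates a request handler running on a worker thread.
    results.append(registry.get(company_id)[0])


workers = [threading.Thread(target=handle_request, args=(cid,)) for cid in (1, 2, 1, 2)]
for w in workers:
    w.start()
for w in workers:
    w.join()

print(load_log)         # both entries name the startup thread (MainThread when run as a script)
print(sorted(results))  # ['model-graph-1', 'model-graph-1', 'model-graph-2', 'model-graph-2']
```

Loading eagerly keeps slow disk reads out of the request path and means each model is loaded exactly once; the trade-off is startup time and memory for companies that may never be queried.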