Is it possible to load a TensorFlow .pb file into Keras as weights for a model?
I have followed the TensorFlow retraining example for my specific classification task, and I have Grad-CAM visualization code written in Keras.
Usually, I load pre-trained weights such as VGG16 or Inception-v3 in .h5 format, and that works very well for my Grad-CAM work. The problem is the retrained_graph.pb produced by the TensorFlow retraining process. I have no idea whether there is any workaround, such as:
- mapping the .pb file to .h5?
- or does Keras have any interface to load a .pb file the same way it loads an .h5 file?
Note: I use TensorFlow as the backend.
Please advise.
Issue analytics:
- Created 6 years ago
- Reactions: 35
- Comments: 21 (1 by maintainers)
This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.
You need to know the graph definition of your .pb and copy all of its weights into each Keras layer. (The code formatting seems to have some problems.)
import tensorflow as tf
from tensorflow.python.framework import tensor_util
from keras.applications.resnet50 import ResNet50
from keras.layers import Conv2D

GRAPH_PB_PATH = 'xxx.pb'  # path to your .pb file

# Parse the frozen GraphDef; freezing stores trained variables as Const nodes
print('load graph')
with tf.io.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

wts = [n for n in graph_def.node if n.op == 'Const']
# Map each Const node name (e.g. 'conv1/kernel') to its index in wts
weight_dict = {n.name: i for i, n in enumerate(wts)}

model = ResNet50(input_shape=(224, 224, 3), include_top=True)
model.summary()

# Copy kernel and bias into every Conv2D layer whose name matches a Const node;
# other layer types (Dense, BatchNormalization, ...) need the same treatment
for layer in model.layers:
    if len(layer.get_weights()) == 0:
        continue  # layers without weights (activations, pooling, ...)
    if isinstance(layer, Conv2D):
        kname = layer.name + '/kernel'
        bname = layer.name + '/bias'
        if kname not in weight_dict or bname not in weight_dict:
            print(kname, bname)  # not found in the .pb
        else:
            weights = []
            for wname in (kname, bname):
                wtensor = wts[weight_dict[wname]].attr['value'].tensor
                weights.append(tensor_util.MakeNdarray(wtensor))
            layer.set_weights(weights)  # Keras expects [kernel, bias]

model.save('tmp.h5')
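The snippet above only handles Conv2D layers. To cover more layer types, the Const values have to be passed to set_weights() in the exact order each Keras layer expects. Below is a minimal sketch of a hypothetical helper, weights_for_layer, that does this for Conv2D, Dense and BatchNormalization. It assumes the frozen graph keeps Keras-style Const names such as conv1/kernel or bn_conv1/moving_mean. That is an assumption: a graph produced by the TensorFlow retraining script may use entirely different names, so compare against model.summary() first. It also assumes default layer settings (use_bias=True, center=scale=True).

```python
from typing import Any, Dict, List, Optional

# Order in which Keras layers return/accept weights via get_weights()/set_weights().
# Assumes defaults: use_bias=True for Conv2D/Dense, center=scale=True for BatchNormalization.
WEIGHT_ORDER: Dict[str, List[str]] = {
    "Conv2D": ["kernel", "bias"],
    "Dense": ["kernel", "bias"],
    "BatchNormalization": ["gamma", "beta", "moving_mean", "moving_variance"],
}


def weights_for_layer(
    layer_name: str, layer_type: str, consts: Dict[str, Any]
) -> Optional[List[Any]]:
    """Hypothetical helper: collect '<layer_name>/<suffix>' Const values in Keras order.

    Returns None for layer types not listed in WEIGHT_ORDER (pooling, activations, ...);
    raises KeyError if the graph lacks one of the expected Const nodes.
    """
    suffixes = WEIGHT_ORDER.get(layer_type)
    if suffixes is None:
        return None
    names = [f"{layer_name}/{suffix}" for suffix in suffixes]
    missing = [name for name in names if name not in consts]
    if missing:
        raise KeyError(f"Const nodes not found in graph: {missing}")
    return [consts[name] for name in names]


if __name__ == "__main__":
    # Toy stand-in for Const values read from a frozen .pb (placeholders, not real weights)
    consts = {
        "conv1/kernel": "K", "conv1/bias": "B",
        "bn_conv1/gamma": 1, "bn_conv1/beta": 2,
        "bn_conv1/moving_mean": 3, "bn_conv1/moving_variance": 4,
    }
    print(weights_for_layer("conv1", "Conv2D", consts))  # ['K', 'B']
    print(weights_for_layer("bn_conv1", "BatchNormalization", consts))  # [1, 2, 3, 4]
    print(weights_for_layer("pool1", "MaxPooling2D", consts))  # None
```

With TensorFlow available, consts could be built as {n.name: tensor_util.MakeNdarray(n.attr['value'].tensor) for n in graph_def.node if n.op == 'Const'}. Each layer would then be filled with layer.set_weights(w) whenever weights_for_layer(layer.name, type(layer).__name__, consts) returns a list. Raising on a missing name, instead of silently skipping it, makes naming mismatches between the .pb and the Keras model visible immediately.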