Is it possible to load a TensorFlow .pb file into Keras as weights for a model?


I have followed the TensorFlow retraining example for my specific classification task, and I have Grad-CAM visualization code written in Keras.

Usually I load pre-trained weights such as VGG16 or Inception-v3 in .h5 format, and that works very well for my Grad-CAM code; the typical flow is sketched after the list below. The problem is that the TensorFlow retraining process produces a retrained_graph.pb, and I have no idea whether there is any workaround, such as:

  • mapping the .pb file to .h5?

  • or does Keras have an interface to load a .pb file in the same way it loads an .h5 file?
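
For reference, the usual .h5 flow I mean is roughly this (a minimal sketch; the local weights file name is illustrative):

from keras.applications.vgg16 import VGG16

# Keras ships ImageNet weights as .h5 files and downloads them on demand,
model = VGG16(weights='imagenet', include_top=True)

# or the same architecture can load weights from a local .h5 file.
model = VGG16(weights=None, include_top=True)
model.load_weights('my_vgg16_weights.h5')  # illustrative file name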

Note: I use TensorFlow as the backend.

Please advise

Issue Analytics

  • State: closed
  • Created 6 years ago
  • Reactions: 35
  • Comments: 21 (1 by maintainers)

Top GitHub Comments

14 reactions
stale[bot] commented on Sep 28, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

2 reactions
guoxiaolu commented on Jul 4, 2019

You need to know the graph definition of your .pb and copy its weights into the corresponding Keras layers. For example:

import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import tensor_util
from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, Convolution2D, BatchNormalization

GRAPH_PB_PATH = 'xxx.pb'  # path to your .pb file

# Load the frozen graph and collect its nodes.
with tf.Session() as sess:
    print('load graph')
    with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        graph_nodes = [n for n in graph_def.node]

# The trained weights live in the Const nodes; index them by name.
wts = [n for n in graph_nodes if n.op == 'Const']
weight_dict = {n.name: i for i, n in enumerate(wts)}

def fetch_weight(wname):
    """Return the value of the Const node called wname as a numpy array."""
    wtensor = wts[weight_dict[wname]].attr['value'].tensor
    return tensor_util.MakeNdarray(wtensor)

# Build the matching Keras architecture, then copy the weights layer by layer.
model = ResNet50(input_shape=(224, 224, 3), include_top=True)
model.summary()

for layer in model.layers:
    name = layer.name
    if len(layer.get_weights()) == 0:
        continue

    # Each Keras layer type stores its weights in a fixed order; the node
    # names in the .pb are assumed to follow the '<layer>/<variable>' scheme.
    if isinstance(layer, (Convolution2D, Dense)):
        wnames = [name + '/kernel', name + '/bias']
    elif isinstance(layer, BatchNormalization):
        wnames = [name + '/gamma', name + '/beta',
                  name + '/moving_mean', name + '/moving_variance']
    else:
        print(name)  # layer type not handled here
        continue

    if all(w in weight_dict for w in wnames):
        layer.set_weights([fetch_weight(w) for w in wnames])
    else:
        print(wnames)  # some expected weight nodes are missing from the .pb

model.save('tmp.h5')
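
Before running the mapping above, it helps to confirm that the constant names in your .pb actually follow the '<layer>/kernel', '<layer>/gamma', etc. convention it expects. A small sketch for dumping those names with the same TF1-style APIs (the file name is illustrative):

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = 'retrained_graph.pb'  # illustrative path to the frozen graph

graph_def = tf.GraphDef()
with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
    graph_def.ParseFromString(f.read())

# List every constant node and its shape; these names must line up with the
# Keras layer naming scheme used in the mapping code above.
for node in graph_def.node:
    if node.op == 'Const':
        shape = [d.size for d in node.attr['value'].tensor.tensor_shape.dim]
        print(node.name, shape)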

