How to freeze the model?
I have trained this model on my own data and now have some .ckpt files. I want to generate a .pb file for use on mobile, but I don't know the input and output node names. I printed the node names while running demo.py on test images; the output is below:
Placeholder
Placeholder_1
Placeholder_2
.......
vgg_16/cls_score/weights/Initializer/random_normal/shape
vgg_16/cls_score/weights/Initializer/random_normal/mean
vgg_16/cls_score/weights/Initializer/random_normal/stddev
vgg_16/cls_score/weights/Initializer/random_normal/RandomStandardNormal
vgg_16/cls_score/weights/Initializer/random_normal/mul
vgg_16/cls_score/weights/Initializer/random_normal
vgg_16/cls_score/weights
vgg_16/cls_score/weights/Assign
vgg_16/cls_score/weights/read
vgg_16_3/cls_score/kernel/Regularizer/l2_regularizer/scale
vgg_16_3/cls_score/kernel/Regularizer/l2_regularizer/L2Loss
vgg_16_3/cls_score/kernel/Regularizer/l2_regularizer
vgg_16/cls_score/biases/Initializer/Const
vgg_16/cls_score/biases
vgg_16/cls_score/biases/Assign
vgg_16/cls_score/biases/read
vgg_16_3/cls_score/MatMul
vgg_16_3/cls_score/BiasAdd
vgg_16_3/cls_prob
vgg_16_3/cls_pred/dimension
vgg_16_3/cls_pred
vgg_16/bbox_pred/weights/Initializer/random_normal/shape
vgg_16/bbox_pred/weights/Initializer/random_normal/mean
vgg_16/bbox_pred/weights/Initializer/random_normal/stddev
vgg_16/bbox_pred/weights/Initializer/random_normal/RandomStandardNormal
vgg_16/bbox_pred/weights/Initializer/random_normal/mul
vgg_16/bbox_pred/weights/Initializer/random_normal
vgg_16/bbox_pred/weights
vgg_16/bbox_pred/weights/Assign
vgg_16/bbox_pred/weights/read
vgg_16_3/bbox_pred/kernel/Regularizer/l2_regularizer/scale
vgg_16_3/bbox_pred/kernel/Regularizer/l2_regularizer/L2Loss
vgg_16_3/bbox_pred/kernel/Regularizer/l2_regularizer
vgg_16/bbox_pred/biases/Initializer/Const
vgg_16/bbox_pred/biases
vgg_16/bbox_pred/biases/Assign
vgg_16/bbox_pred/biases/read
vgg_16_3/bbox_pred/MatMul
vgg_16_3/bbox_pred/BiasAdd
I guessed that the output nodes are vgg_16_3/cls_prob and a few others like it, and I generated the .pb file using those output node names. But when I restore the .pb file and feed an image to the graph using the code below:
` sess = tf.Session()
with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
sess.run(tf.global_variables_initializer())
image = sess.graph.get_tensor_by_name('Placeholder:0')
image_info = sess.graph.get_tensor_by_name('Placeholder_1:0')
gt = sess.graph.get_tensor_by_name("Placeholder_2:0")
score = sess.graph.get_tensor_by_name('SCORE/vgg_16_3/cls_prob/cls_prob/scores:0')
bbox = sess.graph.get_tensor_by_name('SCORE/vgg_16_3/bbox_pred/BiasAdd/bbox_pred/scores:0')
rand_array = np.random.rand(1024, 5)
x_c = tf.constant(rand_array, dtype=tf.float32)
_, scores, bbox_pred, rois = sess.run([score, bbox], feed_dict={image: blobs["data"], image_info: blobs["im_info"], gt: x_c})
`
I get the following error:
File "/home/chenxingli/anaconda3/envs/python2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 948, in _run
    raise TypeError('The value of a feed cannot be a tf.Tensor object. '
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.
Please help me with this error; I have tried everything I can think of. Thank you so much!
Top GitHub Comments
I solved this problem by using demo.py to generate the .pb file. demo.py loads the .ckpt file to run its test, so you can add the following code there, after the checkpoint has been restored into sess, to convert the .ckpt to a .pb:
from tensorflow.python.framework import graph_util

graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
output_graph = "frozen_model.pb"
# Comma-separated names of the output nodes to keep in the frozen graph.
output_node_names = "vgg_16_3/cls_prob,vgg_16_3/bbox_pred/BiasAdd,vgg_16_1/rois/concat,vgg_16_3/cls_score/BiasAdd"
# Replace the variables with constants holding the values restored into sess.
output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(","))
with tf.gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
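The output node names in the snippet above have to be picked out of the full node list, which is mostly variable bookkeeping (Initializer, Assign, read, Regularizer). A minimal sketch of one way to narrow the list down, using plain string filtering on names like those printed in the question. The helper name `candidate_output_nodes` and the skip rules are assumptions for illustration, not part of the repository; the real list would come from something like `[n.name for n in sess.graph.as_graph_def().node]`.

```python
def candidate_output_nodes(node_names):
    """Drop nodes that look like variable bookkeeping rather than inference outputs."""
    skip_parts = ("Initializer", "Regularizer")            # whole path components to skip
    skip_suffixes = ("/Assign", "/read", "/weights", "/biases")
    keep = []
    for name in node_names:
        if any(part in name.split("/") for part in skip_parts):
            continue
        if name.endswith(skip_suffixes):
            continue
        keep.append(name)
    return keep


# A few names copied from the node list in the question.
names = [
    "vgg_16/cls_score/weights/Initializer/random_normal",
    "vgg_16/cls_score/weights",
    "vgg_16/cls_score/weights/Assign",
    "vgg_16/cls_score/weights/read",
    "vgg_16_3/cls_score/kernel/Regularizer/l2_regularizer",
    "vgg_16_3/cls_score/BiasAdd",
    "vgg_16_3/cls_prob",
    "vgg_16_3/bbox_pred/BiasAdd",
]
candidates = candidate_output_nodes(names)
print("\n".join(candidates))
```

This only shortens the list; which of the remaining nodes are the real outputs still has to be checked against the network code.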
By the way, you need to find the names of the nodes you need. In my case, I reloaded the .pb file to check that it was correct, and it showed several boxes on a single object, so I added the following code to correct it:
stds = np.tile(np.array(cfg.TRAIN.BBOX_NORMALIZE_STDS), (2))
means = np.tile(np.array(cfg.TRAIN.BBOX_NORMALIZE_MEANS), (2))
bbox_pred *= stds
bbox_pred += means
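To see what this correction does to the numbers, here is a small self-contained sketch with NumPy only. The mean and std values and the class count of 2 are assumptions chosen for illustration; in the real code they come from cfg.TRAIN.BBOX_NORMALIZE_MEANS, cfg.TRAIN.BBOX_NORMALIZE_STDS and the number of classes the model was trained with. The function name `unnormalize` is hypothetical.

```python
import numpy as np

# Assumed illustrative values; the real ones are read from cfg.TRAIN.
BBOX_NORMALIZE_MEANS = (0.0, 0.0, 0.0, 0.0)
BBOX_NORMALIZE_STDS = (0.1, 0.1, 0.2, 0.2)
NUM_CLASSES = 2  # the tile count used in the snippet above


def unnormalize(bbox_pred, means, stds, num_classes):
    """Undo the per-coordinate normalization applied to the regression targets."""
    # Repeat the 4 per-coordinate values once per class: shape (4 * num_classes,).
    stds = np.tile(np.array(stds), num_classes)
    means = np.tile(np.array(means), num_classes)
    return bbox_pred * stds + means


# Dummy network output: 3 ROIs, 4 box deltas per class.
raw = np.ones((3, 4 * NUM_CLASSES), dtype=np.float32)
fixed = unnormalize(raw, BBOX_NORMALIZE_MEANS, BBOX_NORMALIZE_STDS, NUM_CLASSES)
print(fixed[0])
```

The tile count has to equal the number of classes, otherwise the tiled arrays will not broadcast against bbox_pred.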
That's my solution.
I ran into the same problem. Recompiling the .so file still did not fix it. How did you solve it in the end?
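On the TypeError in the question itself: the message says a feed value cannot be a tf.Tensor, and in the question's code the value fed for Placeholder_2:0 is x_c, which is built with tf.constant. A minimal sketch of preparing the feed values as NumPy arrays instead, assuming the placeholder names from the question; the helper name `build_feed` and the dummy blob shapes are hypothetical.

```python
import numpy as np


def build_feed(blobs, num_gt_boxes=1024):
    """Return {tensor name: value} where every value is a NumPy array."""
    # Feed the random array directly instead of wrapping it in tf.constant.
    gt_boxes = np.random.rand(num_gt_boxes, 5).astype(np.float32)
    return {
        "Placeholder:0": np.asarray(blobs["data"], dtype=np.float32),
        "Placeholder_1:0": np.asarray(blobs["im_info"], dtype=np.float32),
        "Placeholder_2:0": gt_boxes,
    }


# Dummy blobs standing in for the ones demo.py builds from an image.
blobs = {"data": np.zeros((1, 4, 4, 3)), "im_info": [4.0, 4.0, 1.0]}
feed = build_feed(blobs)

# With a session holding the imported graph, the names map to tensors like this:
#   feed_dict = {sess.graph.get_tensor_by_name(k): v for k, v in feed.items()}
#   scores, bbox_pred = sess.run([score, bbox], feed_dict=feed_dict)
print({name: value.shape for name, value in feed.items()})
```

Note that sess.run returns one value per fetched tensor, so fetching [score, bbox] yields two results rather than the four unpacked in the question.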