Is it possible to save a trained model in TF SavedModel format? If so, how?

This is a question, but I don’t know how to add the question tag.

My question is about exporting my model for use in other systems. The ultimate goal is to get it into ONNX format, which I intend to do with tf2onnx. However, the preferred input format for tf2onnx is the TensorFlow SavedModel format, so I would like to export to that format.

Is this possible, and if so, how?

I tried the following:

import gym
import tensorflow as tf

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1


#with tf.Graph().as_default():
#  with tf.Session() as sess:

env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])

model = PPO1(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)

#tf.global_variables_initializer().run(session=model.sess)

init = tf.global_variables_initializer()

model.sess.run(init)

saver = tf.train.Saver()

saver.save(model.sess, '/my/redacted/save/dir')

But this fails with the following error

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    302         self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 303             fetch, allow_tensor=True, allow_operation=True))
    304       except TypeError as e:

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3795     with self._lock:
-> 3796       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3797 

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3879       if obj.graph is not self:
-> 3880         raise ValueError("Operation %s is not an element of this graph." % obj)
   3881       return obj

ValueError: Operation name: "init"
op: "NoOp"
 is not an element of this graph.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py in <module>()
     22 init = tf.global_variables_initializer()
     23 
---> 24 model.sess.run(init)
     25 
     26 tf.train.Saver ()

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    948     try:
    949       result = self._run(None, fetches, feed_dict, options_ptr,
--> 950                          run_metadata_ptr)
    951       if run_metadata:
    952         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1156     # Create a fetch handler to take care of the structure of fetches.
   1157     fetch_handler = _FetchHandler(
-> 1158         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1159 
   1160     # Run request and get response.

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    472     """
    473     with graph.as_default():
--> 474       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    475     self._fetches = []
    476     self._targets = []

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    272         if isinstance(fetch, tensor_type):
    273           fetches, contraction_fn = fetch_fn(fetch)
--> 274           return _ElementFetchMapper(fetches, contraction_fn)
    275     # Did not find anything.
    276     raise TypeError('Fetch argument %r has invalid type %r' % (fetch,

/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
    308       except ValueError as e:
    309         raise ValueError('Fetch argument %r cannot be interpreted as a '
--> 310                          'Tensor. (%s)' % (fetch, str(e)))
    311       except KeyError as e:
    312         raise ValueError('Fetch argument %r cannot be interpreted as a '

ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
 is not an element of this graph.)

I’ve tried various things, but I’m floundering about a bit. Can anyone confirm it’s even possible and if so maybe give some pointers?

I think this would be generally useful for the community.

System Info

Mint Linux 19.3

Everything was installed via pip3

Python version: 3.6.9
TensorFlow version: 1.14
Stable Baselines version: 2.9.0
tf2onnx version: 1.5.4
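For the last step the question is aiming at (SavedModel to ONNX), tf2onnx's command-line converter can be pointed at a SavedModel directory. The sketch below only assembles and runs that command; the `--saved-model`, `--output` and `--opset` flags are taken from the tf2onnx README as I understand it, so check `python -m tf2onnx.convert --help` for the installed version (1.5.4 here). The helper names `tf2onnx_command` and `convert` are hypothetical, not part of tf2onnx.

```python
import subprocess
import sys


def tf2onnx_command(saved_model_dir, output_path, opset=None):
    """Build (but do not run) a tf2onnx conversion command for a SavedModel.

    Hypothetical helper: the flags are assumed from the tf2onnx README.
    """
    cmd = [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", saved_model_dir,  # directory written by simple_save
        "--output", output_path,           # destination .onnx file
    ]
    if opset is not None:
        cmd += ["--opset", str(opset)]     # optional target ONNX opset
    return cmd


def convert(saved_model_dir, output_path, opset=None):
    """Run the converter in a subprocess; raises CalledProcessError on failure."""
    subprocess.run(tf2onnx_command(saved_model_dir, output_path, opset), check=True)
```

Once a SavedModel directory exists, usage would be something like `convert(path, "cartpole_ppo2.onnx")`.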

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments:7

Top GitHub Comments

crobarcro commented, Feb 13, 2020

I did look at this, yes, but while it has some pointers, it doesn’t quite have a full example anywhere. I haven’t actually tried the simple_save method yet; I will do this and report back.

If I get it to work I’d be happy to update the docs, even just for my own records.

crobarcro commented, Feb 14, 2020

Thanks. I had actually seen that issue and made the change after @Miffyli pointed it out (I had skimmed both issues previously but missed the crucial details).

Anyway, I tried the suggested names, see the script below, which just attempts to save the model and then load it again with tensorflow:

import shutil, os
import gym
import tensorflow as tf

from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2


#with tf.Graph().as_default():
#  with tf.Session() as sess:

env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)


containing_dir = os.path.dirname(os.path.realpath(__file__))

path = os.path.join(containing_dir, 'export_model_test')

shutil.rmtree(path, ignore_errors=True)

os.mkdir(path)

##########          using simple_save          #############
inputs_dict = {
               #"obs": model.policy_tf.obs_ph
               #"obs": model.policy.obs_ph
               "obs": model.act_model.obs_ph
             }

outputs_dict = {
   #"action": model.policy.action_ph
   #"action": model.action_ph
   "action": model.act_model._policy_proba
}

tf.saved_model.simple_save(
   model.sess, path, inputs_dict, outputs_dict
)
#############################################################

##########          using tf.train.Saver         #############
# init = tf.global_variables_initializer()
# model.sess.run(init)
# saver = tf.train.Saver()
# saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
# tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)
#####################################################################

# ##########          using import/export meta_graph          #############
# meta_file = os.path.join(path, 'saved_model.meta')
# meta_graph_def = tf.train.export_meta_graph( filename = meta_file,
#                                              graph=model.graph,
#                                              graph_def=model.graph.as_graph_def() )
############################################################################

# I think I need to close the session to free any resources
model.sess.close()

##########          using tf.saved_model.loader         #############
with tf.Session() as sess:
   tf.saved_model.loader.load(sess, [tf.saved_model.SERVING], path)

   graph = tf.get_default_graph()

   print(graph.get_operations())
#######################################################################

# ##########          using tf.saved_model.loader             #############
# restored_graph = tf.Graph()
# with restored_graph.as_default():
#     with tf.Session() as sess:
#         tf.saved_model.loader.load(
#             sess,
#             [tf.saved_model.SERVING],
#             path,
#         )
#         obs_placeholder = restored_graph.get_tensor_by_name('obs:0')
#
#         sess.run(prediction, feed_dict={
#             obs_placeholder: some_value,
#         })
############################################################################


# ##########          using import_meta_graph          #############
# with tf.Session() as sess:
#     new_saver = tf.train.import_meta_graph(meta_file)
#     new_saver.restore(sess, meta_file)
#
#     # sess.run(prediction, feed_dict={
#     #          obs_placeholder: some_value,
#     #      })
######################################################################

With this, however, I get the following output.

INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_model_test/saved_model.pb
WARNING:tensorflow:From /home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py:67: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored.
[]
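The empty operation list above fits the same graph mismatch that is behind the earlier `init` error: stable-baselines builds each model in its own `tf.Graph` (`model.graph`), while in TF 1.x `simple_save` appears to create its `Saver` and export the MetaGraph from whatever graph is currently the *default*, which here is an empty one. A minimal sketch of the workaround, assuming TF 1.14 and a stable-baselines 2.x `PPO2` model, is to make `model.graph` the default while saving. The helper names `build_signature_dicts` and `export_saved_model` are hypothetical; the `obs`/`action` keys and the `act_model.obs_ph` / `act_model._policy_proba` tensors are the ones used in the script above.

```python
def build_signature_dicts(act_model):
    """Map serving-signature keys to the policy's input and output tensors."""
    inputs = {"obs": act_model.obs_ph}             # observation placeholder
    outputs = {"action": act_model._policy_proba}  # action probabilities (discrete actions)
    return inputs, outputs


def export_saved_model(model, export_dir):
    """Write `model` as a SavedModel; export_dir should be absent or empty."""
    import tensorflow as tf  # TF 1.x assumed; imported lazily

    inputs, outputs = build_signature_dicts(model.act_model)
    # Make the model's own graph the default so that simple_save's Saver
    # and MetaGraph export see the trained variables and ops.
    with model.graph.as_default():
        tf.saved_model.simple_save(model.sess, export_dir, inputs, outputs)
```

In the script above this would stand in for the bare `tf.saved_model.simple_save(...)` call, i.e. `export_saved_model(model, path)` before `model.sess.close()`. No `global_variables_initializer()` is needed; running it on the trained session would reset the learned weights.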

There are two commented-out sections, which represent alternative methods of saving/loading that I have found. If I try to save with tf.train.Saver, like this:

init = tf.global_variables_initializer()
model.sess.run(init)
saver = tf.train.Saver()
saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)

I get

ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
 is not an element of this graph.)

If I do this:

#init = tf.global_variables_initializer()
# model.sess.run(init)
saver = tf.train.Saver()
saver.save(model.sess, os.path.join(path, 'tensorflowModel.ckpt'))
tf.train.write_graph(model.sess.graph.as_graph_def(), path, 'tensorflowModel.pbtxt', as_text=True)

I get this:

ValueError: No variables to save
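Both errors point at one cause: `tf.global_variables_initializer()` and `tf.train.Saver()` operate on the default graph, not on `model.graph`, where stable-baselines keeps the policy's variables. So the `init` op lands in a graph the session does not own, and the `Saver` finds no variables. A sketch under the same assumptions (TF 1.14, stable-baselines 2.x) that builds the `Saver` inside `model.graph` instead; it deliberately skips the initializer, since running it on the trained session would re-initialize the weights. `uses_graph` and `save_checkpoint` are hypothetical helper names.

```python
import os


def uses_graph(tf_object, sess):
    """Diagnostic: True if a tf.Operation/tf.Tensor belongs to the session's graph."""
    return tf_object.graph is sess.graph


def save_checkpoint(model, export_dir, name="tensorflowModel"):
    """Save model.sess's variables plus a text GraphDef into export_dir."""
    import tensorflow as tf  # TF 1.x assumed; imported lazily

    with model.graph.as_default():
        saver = tf.train.Saver()  # now collects model.graph's variables
        ckpt_path = saver.save(model.sess, os.path.join(export_dir, name + ".ckpt"))
    tf.train.write_graph(model.graph.as_graph_def(), export_dir, name + ".pbtxt", as_text=True)
    return ckpt_path
```

For example, `uses_graph(init, model.sess)` should come back `False` for the `init` op created in the snippets above, and `save_checkpoint(model, path)` would replace the Saver/write_graph lines.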

Exporting the meta graph seems to work, but when I load it there doesn’t seem to be anything in the graph.
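On the loading side, a MetaGraph only describes the graph structure; trained values come from a checkpoint, and `Saver.restore` expects the checkpoint prefix (e.g. `.../tensorflowModel.ckpt`), not the `.meta` path. For a SavedModel written by `simple_save`, a sketch assuming TF 1.14 and the `obs`/`action` keys from the script above: load into a fresh graph, look the tensors up through the `serving_default` signature that `simple_save` writes, and run inference. `signature_tensor_names` and `load_and_predict` are hypothetical helper names.

```python
SERVING_KEY = "serving_default"  # signature key used by simple_save


def signature_tensor_names(meta_graph_def, in_key="obs", out_key="action",
                           signature_key=SERVING_KEY):
    """Return (input, output) tensor names recorded in a SavedModel signature."""
    signature = meta_graph_def.signature_def[signature_key]
    return signature.inputs[in_key].name, signature.outputs[out_key].name


def load_and_predict(export_dir, observations):
    """Load a SavedModel into a fresh graph and evaluate the 'action' output."""
    import tensorflow as tf  # TF 1.x assumed; imported lazily

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        meta_graph_def = tf.saved_model.loader.load(sess, [tf.saved_model.SERVING], export_dir)
        obs_name, out_name = signature_tensor_names(meta_graph_def)
        # Tensor names (e.g. "...:0") can be used directly as fetches and feed keys.
        return sess.run(out_name, feed_dict={obs_name: observations})
```

For CartPole, `observations` would presumably be a batch of shape `(n, 4)`, e.g. `load_and_predict(path, [[0.0, 0.0, 0.0, 0.0]])`.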
