Is it possible to save trained model as TF saved_model format? If so how?
This is a question, but I don’t know how to add the question tag.
My question is about exporting my model for use in other systems. The ultimate goal is to get it into ONNX format, which I intend to do using tf2onnx. However, the preferred input format for tf2onnx is the TensorFlow saved_model format, so I would like to export to this format.
Is this possible, and if so, how?
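For context, tf2onnx ships a command-line converter that accepts a SavedModel directory. A minimal sketch of invoking it from Python follows. The helper names and the paths in the usage comment are hypothetical; the `--saved-model` and `--output` flags are the ones documented for `python -m tf2onnx.convert`.

```python
import subprocess
import sys


def build_tf2onnx_command(saved_model_dir, onnx_path):
    """Return the argv list for tf2onnx's SavedModel converter."""
    return [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", saved_model_dir,
        "--output", onnx_path,
    ]


def convert_to_onnx(saved_model_dir, onnx_path):
    """Run the converter; raises CalledProcessError if tf2onnx fails."""
    subprocess.run(build_tf2onnx_command(saved_model_dir, onnx_path), check=True)


# Hypothetical usage (requires tf2onnx to be installed):
# convert_to_onnx("/tmp/ppo1_saved_model", "/tmp/ppo1.onnx")
```

Building the argument list separately from running it makes it easy to print and inspect the command before anything is executed.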
I tried the following:
import gym
import tensorflow as tf
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1
#with tf.Graph().as_default():
# with tf.Session() as sess:
env = gym.make('CartPole-v1')
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])
model = PPO1(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=1000)
#tf.global_variables_initializer().run(session=model.sess)
init = tf.global_variables_initializer()
model.sess.run(init)
tf.train.Saver ()
saver.save(model.sess, '/my/redacted/save/dir')
But this fails with the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
302 self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 303 fetch, allow_tensor=True, allow_operation=True))
304 except TypeError as e:
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
3795 with self._lock:
-> 3796 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3797
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
3879 if obj.graph is not self:
-> 3880 raise ValueError("Operation %s is not an element of this graph." % obj)
3881 return obj
ValueError: Operation name: "init"
op: "NoOp"
is not an element of this graph.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
/home/rcrozier/src/ceorl_core-refactor-hg/openai_gym/common/export_trained_model_to_ONNX_example.py in <module>()
22 init = tf.global_variables_initializer()
23
---> 24 model.sess.run(init)
25
26 tf.train.Saver ()
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
948 try:
949 result = self._run(None, fetches, feed_dict, options_ptr,
--> 950 run_metadata_ptr)
951 if run_metadata:
952 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1156 # Create a fetch handler to take care of the structure of fetches.
1157 fetch_handler = _FetchHandler(
-> 1158 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1159
1160 # Run request and get response.
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
472 """
473 with graph.as_default():
--> 474 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
475 self._fetches = []
476 self._targets = []
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
272 if isinstance(fetch, tensor_type):
273 fetches, contraction_fn = fetch_fn(fetch)
--> 274 return _ElementFetchMapper(fetches, contraction_fn)
275 # Did not find anything.
276 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
/home/rcrozier/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
308 except ValueError as e:
309 raise ValueError('Fetch argument %r cannot be interpreted as a '
--> 310 'Tensor. (%s)' % (fetch, str(e)))
311 except KeyError as e:
312 raise ValueError('Fetch argument %r cannot be interpreted as a '
ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. (Operation name: "init"
op: "NoOp"
is not an element of this graph.)
I’ve tried various things, but I’m floundering about a bit. Can anyone confirm it’s even possible and if so maybe give some pointers?
I think this would be generally useful for the community.
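The traceback above says that the `init` op is not an element of the graph that `model.sess` is bound to, which suggests the op was created in TensorFlow's default graph instead. A minimal sketch of one way around this, assuming TF 1.x: create the saver inside the session's own graph (`model.sess.graph`). The helper name `save_checkpoint` is hypothetical. The sketch deliberately does not run a variable initializer, since doing so on a trained model would reset its weights.

```python
def save_checkpoint(model, checkpoint_prefix):
    """Save the variables of a trained model as a TF1 checkpoint.

    `model` is assumed to expose a tf.Session as `model.sess`.
    """
    import tensorflow as tf  # TF 1.x API, imported lazily

    # Ops must be created in the graph the session is bound to; otherwise
    # sess.run() rejects them as "not an element of this graph".
    with model.sess.graph.as_default():
        saver = tf.train.Saver()
        # No global_variables_initializer() here: it would overwrite
        # the trained weights with their initial values.
        return saver.save(model.sess, checkpoint_prefix)
```

This writes a checkpoint, not a SavedModel, so it only addresses the graph mismatch in the script above.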
System Info:
Mint Linux 19.3
Everything was installed via pip3
Python version: 3.6.9
Tensorflow version: 1.14
Stable Baselines version: 2.9.0
tf2onnx version: 1.5.4
Top GitHub Comments
I did look at this, yes, but while this has some pointers it doesn’t quite have a full example anywhere. Actually I haven’t yet tried the simple_save method. I will do this and report back. If I get it to work I’d be happy to update the docs, even just for my own records.
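For reference, a minimal sketch of the `simple_save` route mentioned above, assuming TF 1.x, where `tf.saved_model.simple_save` is available. The helper name `export_saved_model` is hypothetical. Which placeholder and output tensors to pass depends on the algorithm; the attribute names in the usage comment (`policy_pi`, `obs_ph`, `deterministic_action`) are assumptions that should be checked against the installed stable-baselines version.

```python
def export_saved_model(sess, export_dir, inputs, outputs):
    """Write a TF1 SavedModel with a default serving signature.

    inputs / outputs: dicts mapping signature names to tensors that
    belong to sess.graph. export_dir must not already exist.
    """
    import tensorflow as tf  # TF 1.x API, imported lazily

    # Make the session's graph the default one, so the saver that
    # simple_save builds internally sees the model's variables.
    with sess.graph.as_default():
        tf.saved_model.simple_save(sess, export_dir, inputs=inputs, outputs=outputs)
    return export_dir


# Hypothetical usage; the attribute names below are assumptions:
# policy = model.policy_pi
# export_saved_model(
#     model.sess, "/tmp/ppo1_saved_model",
#     inputs={"obs": policy.obs_ph},
#     outputs={"action": policy.deterministic_action},
# )
```

Passing the tensors in explicitly keeps the helper independent of any one algorithm's internal attribute names.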
Thanks, actually I had seen that issue and made the change after @Miffyli had pointed it out (actually I had skimmed over both issues previously, but missed the crucial details).
Anyway, I tried the suggested names, see the script below, which just attempts to save the model and then load it again with tensorflow:
With this, however, I get the following output.
There are two commented-out sections which represent alternative methods of saving/loading which I have found. If I try to save using
if I do this:
I get
If I do this:
I get this:
Exporting the meta graph seems to work, but when I load it there doesn’t seem to be anything in the graph.
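When a loaded graph looks empty, one thing worth checking is whether the graph being inspected is the same one the model was loaded into. A minimal sketch, assuming TF 1.x and a SavedModel exported with the default serving tag: load into a fresh, explicit graph and list the operation names from that same graph. The helper name `list_saved_model_ops` is hypothetical.

```python
def list_saved_model_ops(export_dir, limit=20):
    """Load a TF1 SavedModel into a fresh graph and return some op names."""
    import tensorflow as tf  # TF 1.x API, imported lazily

    graph = tf.Graph()
    with graph.as_default():
        with tf.Session(graph=graph) as sess:
            # Imports the meta graph and restores variables into `graph`.
            tf.saved_model.loader.load(
                sess, [tf.saved_model.tag_constants.SERVING], export_dir
            )
            # Inspect the graph that was loaded into, not some other default graph.
            return [op.name for op in graph.get_operations()][:limit]
```

An empty list here would point at the export step, while a populated list would suggest the earlier inspection looked at a different graph.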