[question] How to save a trained PPO2 agent to use in a Java program?
I am using a PPO2 agent to train on a custom environment. I use the save function to store everything in a .pkl file in the callback function, similar to the example from the Colab notebook.
def callback(_locals, _globals):
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2)
    :param _locals: (dict)
    :param _globals: (dict)
    """
    global n_steps, best_mean_reward, saving_interval, pickle_dir
    # Print stats every X calls
    if (n_steps + 1) % saving_interval == 0:
        # Evaluate policy training performance
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if len(x) > 0:
            mean_reward = np.mean(y[-100:])
            logger.info("{} timesteps".format(x[-1]))
            logger.info(
                "Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(best_mean_reward, mean_reward))
            # New best model, you could save the agent here
            if mean_reward > best_mean_reward:
                best_mean_reward = mean_reward
                # Example for saving best model
                logger.info("Saving new best model")
                _locals['self'].save(pickle_dir + 'ppo2_best_model.pkl')
    n_steps += 1
    # Returning False will stop training early
    return True
What I would like to do is extract from the .pkl file only what is necessary to take an observation and return an action. I would use this data in a Java program to get the action that I need without having to use Python, with something like a function float[] GetAction(float[] observation). I do not need to train the agent. I just need its “final” state and everything needed to take an observation array and create the action array.
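As a rough illustration of what "everything needed" amounts to for a plain feed-forward policy, the sketch below computes an action from an observation using nothing but per-layer weights and biases. It assumes a fully connected network with tanh hidden activations and a linear output layer; the actual architecture depends on the policy class used in training. The name `get_action` and the toy parameters are hypothetical, and the same loop could be ported to Java once the real arrays are exported.

```python
import math


def get_action(observation, layers):
    """Forward pass of a small fully connected policy network.

    `layers` is a list of (weights, biases) pairs, where weights[i][j]
    connects input i to output j. tanh is applied after every layer
    except the last one (assumed architecture, not read from the .pkl).
    """
    activations = list(observation)
    for index, (weights, biases) in enumerate(layers):
        outputs = []
        for j, bias in enumerate(biases):
            total = bias + sum(a * weights[i][j] for i, a in enumerate(activations))
            outputs.append(total)
        if index < len(layers) - 1:
            outputs = [math.tanh(v) for v in outputs]
        activations = outputs
    return activations


# Toy parameters invented for illustration (2 inputs -> 2 hidden -> 1 action).
toy_layers = [
    ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]),
    ([[2.0], [-1.0]], [0.5]),
]
action = get_action([0.0, 0.0], toy_layers)
```

With real weights, the list of layers would be filled from the trained model instead of the toy values.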
I believe the best way to do this would be to use TensorFlow’s API, more specifically the saved_model.simple_save function, documented here. With this, I would be able to load the model in Java using TensorFlow’s Java API. However, I do not know what I should use as inputs and outputs for this function. I have tried to better understand PPO2’s code, but I have limited knowledge of these TensorFlow methods and cannot figure it out.
If someone could point me in the right direction I would appreciate it.
Thanks for your help and awesome work on this repo 😉
Top GitHub Comments
You are correct! Thanks again! 😄 I didn’t know about the sess attribute. I have changed my code to the following, and the weights are consistent in each run (both on the Python side and on the Java side):
I’ll be testing whether the results are the same in Java and Python, but since the weights are, I’m expecting the results to be as well.
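One way such a Python-versus-Java comparison could be set up is to dump a few observation/action pairs to a file on the Python side and replay them in Java with a tolerance. The sketch below is only an illustration: `write_test_vectors` and `max_abs_difference` are hypothetical helpers, and `policy_fn` stands for any callable that maps an observation to an action vector.

```python
import json


def write_test_vectors(policy_fn, observations, path):
    """Store observation/action pairs as JSON so another runtime can replay them."""
    records = []
    for obs in observations:
        action = policy_fn(obs)
        records.append({
            "observation": [float(v) for v in obs],
            "action": [float(v) for v in action],
        })
    with open(path, "w") as handle:
        json.dump(records, handle, indent=2)
    return records


def max_abs_difference(expected, actual):
    """Largest element-wise gap between two equally long action vectors."""
    if len(expected) != len(actual):
        raise ValueError("action vectors differ in length")
    return max((abs(e - a) for e, a in zip(expected, actual)), default=0.0)
```

For a continuous action space, `policy_fn` could be something like `lambda obs: model.predict(obs, deterministic=True)[0]`. Small floating-point differences between the two runtimes are to be expected, so comparing against a tolerance is safer than testing for exact equality.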
Why don’t you use model.sess? (I would have to think about what could go wrong with another session.)
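For readers landing here, a minimal sketch of what passing model.sess to simple_save might look like is given below. It is not the code from this thread. It assumes TensorFlow 1.x and a stable-baselines PPO2 model whose `act_model` exposes `obs_ph`, `deterministic_action` and `action`, which is worth checking against the installed version. `policy_signature` and `export_policy` are hypothetical helper names.

```python
def policy_signature(model, deterministic=True):
    """Pick the tensors that map an observation to an action.

    Assumes `model.act_model` exposes `obs_ph`, `deterministic_action`
    and `action` (an assumption about the stable-baselines policy object).
    """
    policy = model.act_model
    action = policy.deterministic_action if deterministic else policy.action
    return {"obs": policy.obs_ph}, {"action": action}


def export_policy(model, export_dir):
    """Write a SavedModel using the session that holds the trained weights."""
    import tensorflow as tf  # TensorFlow 1.x assumed; imported lazily

    inputs, outputs = policy_signature(model)
    # Reuse the model's own graph and session so the exported variables
    # are the trained ones rather than freshly initialised copies.
    with model.graph.as_default():
        tf.saved_model.simple_save(model.sess, export_dir, inputs=inputs, outputs=outputs)
```

A call could look like `export_policy(PPO2.load(pickle_dir + 'ppo2_best_model.pkl'), 'exported_policy')`, where the target directory must not already exist. Variable values live in the session, so a new session would not carry the trained weights unless they were restored into it, which is the reason for reusing model.sess.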