[question] How to save a trained PPO2 agent to use in a Java program?

I am using a PPO2 agent to train on a custom environment. I use the save function to store everything in a .pkl in the callback function, similar to the example from the Colab notebook.

def callback(_locals, _globals):
    """
    Callback called at each step (for DQN and others) or after n steps (see ACER or PPO2)
    :param _locals: (dict)
    :param _globals: (dict)
    """
    global n_steps, best_mean_reward, saving_interval, pickle_dir

    # Print stats every X calls
    if (n_steps + 1) % saving_interval == 0:
        # Evaluate policy training performance
        x, y = ts2xy(load_results(log_dir), 'timesteps')
        if len(x) > 0:
            mean_reward = np.mean(y[-100:])
            logger.info("{} timesteps".format(x[-1]))
            logger.info(
                "Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(best_mean_reward, mean_reward))

            # New best model, you could save the agent here
            if mean_reward > best_mean_reward:
                best_mean_reward = mean_reward
                # Example for saving best model
                logger.info("Saving new best model")
                _locals['self'].save(pickle_dir + 'ppo2_best_model.pkl')
    n_steps += 1
    # Returning False will stop training early
    return True
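
The "save only when the mean reward improves" logic in the callback above can also be kept in a small object instead of module-level globals. The following is a minimal sketch of that idea, not part of the original issue: the class name `BestRewardTracker` and its parameters are hypothetical, and it only decides when to save. The actual `save` call stays in the callback.

```python
from statistics import mean


class BestRewardTracker:
    """Hypothetical helper: remembers the best windowed mean reward seen so far."""

    def __init__(self, check_interval, window=100):
        self.check_interval = check_interval  # evaluate every N callback calls
        self.window = window                  # number of most recent episodes to average
        self.n_calls = 0
        self.best_mean_reward = float("-inf")

    def should_save(self, episode_rewards):
        """Return True when the latest windowed mean beats the best one so far."""
        self.n_calls += 1
        # Skip calls that are not on the check interval, or when no episode has finished yet.
        if self.n_calls % self.check_interval != 0 or len(episode_rewards) == 0:
            return False
        mean_reward = mean(episode_rewards[-self.window:])
        if mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            return True
        return False
```

Inside the callback, this would be used roughly as `if tracker.should_save(list(y)): _locals['self'].save(...)`, with `y` coming from `ts2xy` as in the code above.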

What I would like to do is extract from the .pkl file only what is necessary to take an observation and return an action. I would use this data in a Java program to get the action I need without having to use Python, through something like a function float[] GetAction(float[] observation). I do not need to train the agent. I just need its “final” state and everything needed to take an observation array and produce the action array.

I believe the best way to do this would be to use TensorFlow’s API, more specifically the saved_model.simple_save function, documented here. With this, I would be able to load the model in Java using the Java API from TensorFlow. However, I do not know what I should use as inputs and outputs for this function. I have tried to better understand PPO2’s code, but I have limited knowledge of these TensorFlow methods and cannot figure it out.

If someone could point me in the right direction I would appreciate it.

Thanks for your help and awesome work on this repo 😉

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 19

Top GitHub Comments

3 reactions
josealeixopc commented, May 23, 2019

You are correct! Thanks again! 😄 I didn’t know about the sess attribute. I have changed my code to the following, and the weights are now consistent in each run (on both the Python side and the Java side):

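import os
import shutil

import tensorflow as tf
from stable_baselines import PPO2


def generate_checkpoint_from_model(model_path, checkpoint_name):
    # Load the trained agent saved earlier with model.save(...)
    model = PPO2.load(model_path)

    with model.graph.as_default():
        # simple_save refuses to write into an existing directory, so clear it first
        if os.path.exists(checkpoint_name):
            shutil.rmtree(checkpoint_name)

        # Export using the model's own session so the trained weights are the ones saved
        tf.saved_model.simple_save(model.sess, checkpoint_name, inputs={"obs": model.act_model.obs_ph},
                                   outputs={"action": model.action_ph})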

I’ll be testing whether the results are the same in Java and Python, but since the weights match, I’m expecting the results to match as well.
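
One way to run that comparison is to record a few observation/action pairs on the Python side and replay them in the other runtime. The following is a minimal sketch under stated assumptions, not code from the original thread: `export_reference_cases` and `max_abs_difference` are hypothetical helper names, observations are drawn uniformly from [-1, 1] purely for illustration, and `predict_fn` stands for any callable that maps an observation list to an action list. For a Stable Baselines model it would presumably wrap `model.predict(obs, deterministic=True)`, since sampled actions differ from run to run.

```python
import json
import random


def export_reference_cases(predict_fn, obs_dim, n_cases, path, seed=0):
    """Write (observation, action) pairs to a JSON file for replay elsewhere."""
    rng = random.Random(seed)  # fixed seed so the same cases can be regenerated
    cases = []
    for _ in range(n_cases):
        # Assumption: observations in [-1, 1]; adapt to the real observation space.
        obs = [rng.uniform(-1.0, 1.0) for _ in range(obs_dim)]
        action = [float(a) for a in predict_fn(obs)]
        cases.append({"obs": obs, "action": action})
    with open(path, "w") as f:
        json.dump(cases, f)
    return cases


def max_abs_difference(expected, actual):
    """Largest element-wise gap between two action vectors of equal length."""
    if len(expected) != len(actual):
        raise ValueError("action vectors differ in length")
    return max((abs(e - a) for e, a in zip(expected, actual)), default=0.0)
```

The Java side can then read the same JSON file, feed each "obs" entry to the loaded model, and check that the returned action stays within a small tolerance of the stored "action".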

1 reaction
araffin commented, May 22, 2019

Why don’t you use model.sess? (I would have to think about what could go wrong with another session.)
