
[question] How to evaluate PPO2 with MlpLnLstmPolicy trained on SubprocVecEnv with nminibatches > 1?

See original GitHub issue

Running the code below raises the following error:

ValueError: Cannot feed value of shape (1, 4) for Tensor 'input/Ob:0', which has shape '(4, 4)'

It looks like, when the model is trained with 4 environments, model.predict(obs) only accepts input with batch size 4.
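The mismatch can be sketched in plain NumPy (the shapes mirror CartPole-v1's 4-dimensional observation; the padding trick shown here is the same one discussed in the comments below):

```python
import numpy as np

N_ENV = 4                   # batch size the recurrent policy was trained with
obs = np.random.rand(1, 4)  # a single observation from one evaluation env

# Feeding `obs` directly fails because the LSTM placeholder expects a batch
# of N_ENV observations. Zero-padding the batch dimension matches the shape:
zero_completed_obs = np.zeros((N_ENV,) + obs.shape[1:])
zero_completed_obs[0, :] = obs[0]

print(zero_completed_obs.shape)  # (4, 4)
```

Only the first row carries real data; the model's prediction at index 0 is the one to use.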

On the other hand, I have found that the evaluation function does not accept multiple environments either:

AssertionError: You must pass only one environment for evaluation

import gym

from stable_baselines.common.policies import MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2
from stable_baselines.common.callbacks import EvalCallback

if __name__ == '__main__':
    N_ENV = 4
    NMINIBATCHES = 4

    def _make_env():
        return gym.make('CartPole-v1')

    train_env = SubprocVecEnv([_make_env for _ in range(N_ENV)])
    train_env = VecNormalize(train_env, norm_obs=True, norm_reward=False)

    eval_env = DummyVecEnv([_make_env])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False)

    eval_callback = EvalCallback(
        eval_env,
        best_model_save_path='./model',
        log_path='./logs',
        eval_freq=100,
        deterministic=True,
        render=False,
        n_eval_episodes=1,
    )

    model = PPO2(
        MlpLnLstmPolicy,
        train_env,
        verbose=1,
        policy_kwargs={'n_lstm': 16},
        nminibatches=NMINIBATCHES,
    )

    model.learn(total_timesteps=4000, callback=eval_callback)

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 7

Top GitHub Comments

6 reactions
Hoiy commented, Apr 4, 2020

To be more specific, I made a new evaluate_lstm_policy() that adds the following to evaluate_policy():

# inside the evaluation loop (np is numpy; NMINIBATCHES matches training)
zero_completed_obs = np.zeros((NMINIBATCHES,) + env.observation_space.shape)
zero_completed_obs[0, :] = obs
obs = zero_completed_obs

and made a new EvalCallback that calls this evaluate_lstm_policy().
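A fuller sketch of what such a helper could look like. This is an assumption-laden illustration, not library code: evaluate_lstm_policy and its n_train_envs parameter are hypothetical, and the state/mask keyword arguments follow stable-baselines' model.predict() signature for recurrent policies.

```python
import numpy as np

def evaluate_lstm_policy(model, env, n_eval_episodes=1, deterministic=True,
                         n_train_envs=4):
    """Evaluate a recurrent model trained with `n_train_envs` parallel envs
    on a single-env VecEnv by zero-padding the observation batch.

    Hypothetical helper; `n_train_envs` must match the number of
    environments (and hence nminibatches) used during training.
    """
    episode_rewards = []
    for _ in range(n_eval_episodes):
        obs = env.reset()
        state = None                       # LSTM hidden state, reset per episode
        done = [False] * n_train_envs      # episode-start mask for the LSTM
        total_reward = 0.0
        episode_done = False
        while not episode_done:
            # Pad the single observation up to the trained batch size.
            zero_completed_obs = np.zeros(
                (n_train_envs,) + env.observation_space.shape)
            zero_completed_obs[0, :] = obs
            action, state = model.predict(zero_completed_obs, state=state,
                                          mask=done, deterministic=deterministic)
            # Step only with the action belonging to the real environment.
            obs, reward, dones, _ = env.step(action[:1])
            total_reward += reward[0]
            episode_done = dones[0]
            done = [episode_done] * n_train_envs
        episode_rewards.append(total_reward)
    return np.mean(episode_rewards), np.std(episode_rewards)
```

Rows 1..n_train_envs-1 of the padded batch are dummies, so only action[0] and the slot-0 reward/done flags are read back.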

0 reactions
alimai commented, Jul 4, 2020

Good and simple enough to merge into the project to solve the evaluation problem.

Read more comments on GitHub.

Top Results From Across the Web

  • Examples — Stable Baselines 2.10.3a0 documentation: This example demonstrates how to train a recurrent policy and how to test it properly. Warning: one current limitation of recurrent policies is...
  • Stable Baselines Documentation - Read the Docs: from stable_baselines import PPO2; model = PPO2('MlpPolicy', 'CartPole-v1').learn(10000). Fig. 1: Define and train a RL agent in one line of...
  • Stable Baselines for Reinforcement Learning: Focussed on policy gradient variance reduction and to get an unbiased estimate; retraces Q-value estimate; applies clipped importance sampling...
  • What is a high performing network architecture to use in a...: Specifically I am training an agent using a PPO2 model. My question is, are there some rules of thumb or best practices in...
  • Python stable_baselines.PPO2 Examples - ProgramCreek.com: test_model_{}.zip'.format(request.node.name) try: # create and train if model_class == PPO2: model = model_class(policy, 'CartPole-v1', nminibatches=1)...
