[question] How to evaluate PPO2 with MlpLnLstmPolicy trained on SubprocVecEnv having nminibatch > 1?
The code below gives the following error message:
ValueError: Cannot feed value of shape (1, 4) for Tensor 'input/Ob:0', which has shape '(4, 4)'
It looks like, when training with 4 environments, the model.predict(obs) method only accepts observations with a batch size of 4.
On the other hand, I found that the evaluation does not accept more than one environment either:
AssertionError: You must pass only one environment for evaluation
import gym
from stable_baselines.common.policies import MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines import PPO2
from stable_baselines.common.callbacks import EvalCallback

if __name__ == '__main__':
    N_ENV = 4
    NMINIBATCHES = 4

    def _make_env():
        return gym.make('CartPole-v1')

    train_env = SubprocVecEnv([_make_env for _ in range(N_ENV)])
    train_env = VecNormalize(train_env, norm_obs=True, norm_reward=False)

    eval_env = DummyVecEnv([_make_env])
    eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False)

    eval_callback = EvalCallback(eval_env,
                                 best_model_save_path='./model',
                                 log_path='./logs',
                                 eval_freq=100,
                                 deterministic=True,
                                 render=False,
                                 n_eval_episodes=1)

    model = PPO2(MlpLnLstmPolicy, train_env,
                 verbose=1,
                 policy_kwargs={'n_lstm': 16},
                 nminibatches=NMINIBATCHES)

    model.learn(total_timesteps=4000, callback=eval_callback)
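For reference, a manual evaluation loop can work around the batch-size mismatch by zero-padding the single-env observation up to N_ENV, keeping the LSTM state between predict() calls, and stepping only the real environment. This is a minimal sketch based on the zero-padding trick from the recurrent-policy example in the Stable Baselines docs; padded_obs, done_mask and episode_reward are illustrative names, and it assumes model and eval_env from the script above:

import numpy as np

# Run after model.learn(): the recurrent policy graph expects N_ENV
# observations per predict() call, so pad the single observation to N_ENV.
obs = eval_env.reset()                                     # shape (1, obs_dim)
padded_obs = np.zeros((N_ENV,) + eval_env.observation_space.shape)
padded_obs[0] = obs[0]
state = None                                               # None = initial LSTM state
done_mask = [False] * N_ENV                                # mask resets LSTM state per env

episode_reward, done = 0.0, False
while not done:
    action, state = model.predict(padded_obs, state=state,
                                  mask=done_mask, deterministic=True)
    obs, reward, dones, _ = eval_env.step(action[:1])      # step only the real env
    padded_obs[0] = obs[0]
    done = dones[0]
    done_mask = [done] + [False] * (N_ENV - 1)
    episode_reward += reward[0]

print('episode reward:', episode_reward)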
Top GitHub Comments
To be more specific, I made a new evaluate_lstm_policy() that adds the following to evaluate_policy(), and a new EvalCallback that calls this evaluate_lstm_policy().

This seems good and simple enough to merge into the project to solve the eval problem.
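The actual patch is not shown in the thread. As a rough, hypothetical sketch of how such a helper and callback could be wired up, reusing the same zero-padding idea as above: evaluate_lstm_policy(), LstmEvalCallback and the n_train_envs parameter are assumed names for illustration, not the original poster's code.

import os

import numpy as np
from stable_baselines.common.callbacks import EvalCallback


def evaluate_lstm_policy(model, env, n_eval_episodes=5, deterministic=True, n_train_envs=4):
    # Hypothetical stand-in for evaluate_policy(): pads the single-env
    # observation up to the training batch size and threads the LSTM state
    # and done mask through predict().
    episode_rewards = []
    for _ in range(n_eval_episodes):
        obs = env.reset()
        padded_obs = np.zeros((n_train_envs,) + env.observation_space.shape)
        padded_obs[0] = obs[0]
        state, done_mask = None, [False] * n_train_envs
        total_reward, done = 0.0, False
        while not done:
            action, state = model.predict(padded_obs, state=state,
                                          mask=done_mask, deterministic=deterministic)
            obs, reward, dones, _ = env.step(action[:1])
            padded_obs[0] = obs[0]
            done = dones[0]
            done_mask = [done] + [False] * (n_train_envs - 1)
            total_reward += reward[0]
        episode_rewards.append(total_reward)
    return np.mean(episode_rewards), np.std(episode_rewards)


class LstmEvalCallback(EvalCallback):
    # Hypothetical callback that calls evaluate_lstm_policy() instead of
    # evaluate_policy(), so the eval env can stay a single DummyVecEnv.
    def _on_step(self):
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            mean_reward, std_reward = evaluate_lstm_policy(
                self.model, self.eval_env,
                n_eval_episodes=self.n_eval_episodes,
                deterministic=self.deterministic,
                n_train_envs=self.model.get_env().num_envs)
            if self.verbose > 0:
                print('Eval: mean_reward={:.2f} +/- {:.2f}'.format(mean_reward, std_reward))
            if mean_reward > self.best_mean_reward:
                self.best_mean_reward = mean_reward
                if self.best_model_save_path is not None:
                    self.model.save(os.path.join(self.best_model_save_path, 'best_model'))
        return True

Such a callback would then be used in place of EvalCallback in the script above, e.g. eval_callback = LstmEvalCallback(eval_env, best_model_save_path='./model', log_path='./logs', eval_freq=100, n_eval_episodes=1).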