evaluate_policy() crashes with PPO2 policies trained on vectorized environments [bug]
Describe the bug
After training PPO2 in a vectorized environment with MlpLstmPolicy, evaluate_policy() rejects evaluation on a vectorized environment via an assert, but then crashes when given a non-vectorized environment. As far as I can tell, this makes evaluate_policy incompatible with PPO2 policies trained in vectorized environments. I think it is reasonable to consider this a bug, since stable baselines fails with a TensorFlow crash rather than with an assert of its own.
I can try fixing the crash myself, but it would probably be faster for someone who understands the recurrent policy implementation better to decide whether this should be fixed properly, or patched with an assert that disallows using PPO2 policies trained in vectorized environments with evaluate_policy.
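For illustration, the assert-style patch mentioned above could look roughly like the sketch below. It assumes, as the traceback in this report suggests, that a recurrent policy's input is fixed to the number of training environments. The helper name check_eval_env_count and its parameters are hypothetical and not part of stable_baselines.

```python
def check_eval_env_count(is_recurrent, n_train_envs, n_eval_envs):
    """Fail early with a readable message instead of a TensorFlow shape error.

    Hypothetical helper for illustration; not part of stable_baselines.
    """
    if is_recurrent and n_train_envs != n_eval_envs:
        raise ValueError(
            "Recurrent policy was trained with {} environments but is being "
            "evaluated with {}; the batch sizes must match.".format(
                n_train_envs, n_eval_envs))


# Feed-forward policies are unaffected by the environment count.
check_eval_env_count(is_recurrent=False, n_train_envs=12, n_eval_envs=1)

# The setup from this report would be rejected up front.
try:
    check_eval_env_count(is_recurrent=True, n_train_envs=12, n_eval_envs=1)
except ValueError as err:
    message = str(err)
    print(message)
```

A check like this would not make evaluation work, but it would turn the crash into an error that names the actual cause.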
Code example
from stable_baselines.common.evaluation import evaluate_policy
from stable_baselines.common import make_vec_env
import gym
from stable_baselines import PPO2
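# Train on 12 parallel copies of CartPole
env = make_vec_env('CartPole-v1', n_envs=12)
# Single, non-vectorized environment for evaluation
eval_env = gym.make('CartPole-v1')
model = PPO2('MlpLstmPolicy', env, nminibatches=1, verbose=1)
model.learn(10000)
# Crashes inside TensorFlow with the shape error shown below
(mean, std) = evaluate_policy(model,eval_env, n_eval_episodes = 10)
# Passing the vectorized env instead trips the assert in evaluate_policy
#(mean, std) = evaluate_policy(model,env, n_eval_episodes = 10)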
Output
Traceback (most recent call last):
  File "./minimal_example.py", line 15, in <module>
    (mean, std) = evaluate_policy(model,eval_env, n_eval_episodes = 10)
  File "/home/john/.local/lib/python3.6/site-packages/stable_baselines/common/evaluation.py", line 38, in evaluate_policy
    action, state = model.predict(obs, state=state, deterministic=deterministic)
  File "/home/john/.local/lib/python3.6/site-packages/stable_baselines/common/base_class.py", line 819, in predict
    actions, _, states, _ = self.step(observation, state, mask, deterministic=deterministic)
  File "/home/john/.local/lib/python3.6/site-packages/stable_baselines/common/policies.py", line 505, in step
    {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})
  File "/home/john/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/john/.local/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1149, in _run
    str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (1, 4) for Tensor 'input/Ob:0', which has shape '(12, 4)'
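The error says the policy expects a batch of 12 observations, while a single environment supplies only one. A commonly used workaround for recurrent policies is to zero-pad the single observation up to the training batch size and use only the first returned action. The sketch below shows just the padding step with NumPy. pad_observation is a hypothetical helper, and the observation values are made up.

```python
import numpy as np


def pad_observation(obs, n_envs):
    """Zero-pad a single observation to a batch of n_envs rows.

    Row 0 holds the real observation; the remaining rows are zeros.
    Hypothetical helper for illustration; not part of stable_baselines.
    """
    obs = np.asarray(obs, dtype=np.float32)
    padded = np.zeros((n_envs,) + obs.shape, dtype=obs.dtype)
    padded[0] = obs
    return padded


# Stand-in for a CartPole-v1 observation (4 floats); values are made up.
single_obs = np.array([0.01, -0.02, 0.03, 0.04], dtype=np.float32)
batch = pad_observation(single_obs, n_envs=12)
print(batch.shape)  # (12, 4)
```

In an evaluation loop, the padded batch would presumably be passed to model.predict together with the recurrent state, and only the first action would be stepped in the single environment.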
System Info
Ubuntu 18.04, Python 3.6.9
Additional context
Issue Analytics
- Created 2 years ago
- Comments: 5
Top GitHub Comments
Yep, looked like that fixed it! Yeah I’ll add a PR with an assertion for that. Thanks for your quick responses; you saved me a lot of time.
(leaving this open until the pull request is ready)