
[Question] Different results using MultiInputPolicy and MlpPolicy with the same observation data


I am able to train my custom gym environment with good results using MlpPolicy, but when I change the policy to MultiInputPolicy and wrap my vector observation array in a single-element dictionary, I get completely different results. From the tests it looks like I should get the same results, as both are using the FlattenExtractor. This is with PPO.

Is there any guidance on what to check to see why the results are different? Thank you.

The changes made were:

Change the observation space to a Dict:

self._observation_space = spaces.Box(low=-np.ones(self.num_obs) * np.inf, high=np.ones(self.num_obs) * np.inf, dtype=np.float32)

to

self._observation_space = spaces.Dict(
            spaces={
                "vec": spaces.Box(low=-np.ones(self.num_obs) * np.inf, high=np.ones(self.num_obs) * np.inf, dtype=np.float32)
            }
        )

Change the observation to a dict:

self._observation = np.zeros((self.num_envs, self.num_obs), dtype=np.float32)

to

self._observation = {"vec": np.zeros((self.num_envs, self.num_obs), dtype=np.float32)}

Change the policy argument from MlpPolicy to MultiInputPolicy:

model = PPO('MlpPolicy', env, verbose=2, tensorboard_log=saver.data_dir)

to

model = PPO('MultiInputPolicy', env, verbose=2, tensorboard_log=saver.data_dir)
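One way to sanity-check the converted environment before training is SB3's env checker (note that check_env expects a single, non-vectorized gym.Env). Below is a minimal sketch with a hypothetical stand-in env, since the real environment is not shown; the key point it illustrates is that reset() and step() must now return the dict observation as well:

import gym
import numpy as np
from stable_baselines3.common.env_checker import check_env

class DictObsEnv(gym.Env):
    # Hypothetical single-env stand-in for the converted environment
    def __init__(self, num_obs=10):
        self.num_obs = num_obs
        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Dict(
            {"vec": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(num_obs,), dtype=np.float32)}
        )

    def reset(self):
        # every observation handed to the agent must now be a dict with the same key(s)
        return {"vec": np.zeros(self.num_obs, dtype=np.float32)}

    def step(self, action):
        return {"vec": np.zeros(self.num_obs, dtype=np.float32)}, 0.0, False, {}

check_env(DictObsEnv())  # warns or raises on space/API mismatches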

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
  • I have checked my env using the env checker (required)
  • I have provided a minimal working example to reproduce the bug (required)

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (4 by maintainers)

Top GitHub Comments

3 reactions
JadenTravnik commented, May 22, 2021

Hey @jhurlbut, good question.

I wrote out the code below to verify that the difference between the initial policies is the features_extractor module. Hopefully this code represents your use case; I ignored tensorboard, as it shouldn't change the policy modules (though I haven't double-checked that).

import gym
import numpy as np
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.policies import MultiInputActorCriticPolicy, ActorCriticPolicy

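# Minimal env that exposes either the flat Box observation or the same Box wrapped in a single-key Dict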
class TestEnv(gym.Env):
    def __init__(self, is_dict=False, num_obs=10):
        self.action_space = gym.spaces.Discrete(1)
        obs_space = gym.spaces.Box(low=-np.ones(num_obs) * np.inf, high=np.ones(num_obs) * np.inf, dtype=np.float32)
        if not is_dict:
            self.observation_space = obs_space
        else:
            self.observation_space = gym.spaces.Dict({"vec_obs": obs_space})

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        return self.observation_space.sample(), 0, False, {}

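# Seed torch identically before building each model so both policies start from the same initial weights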
torch.manual_seed(0)
flat_model = PPO("MlpPolicy", TestEnv(is_dict=False), verbose=2)
torch.manual_seed(0)
dict_model = PPO("MultiInputPolicy", TestEnv(is_dict=True), verbose=2)

print("\n", flat_model.policy, dict_model.policy, sep="\n\n")

# asserts that the mlp_extractor weights are the same via the manual seed
assert (dict_model.policy.mlp_extractor.policy_net[0].weight == flat_model.policy.mlp_extractor.policy_net[0].weight).all()
assert (dict_model.policy.mlp_extractor.policy_net[2].weight == flat_model.policy.mlp_extractor.policy_net[2].weight).all()

assert (dict_model.policy.mlp_extractor.value_net[0].weight == flat_model.policy.mlp_extractor.value_net[0].weight).all()
assert (dict_model.policy.mlp_extractor.value_net[2].weight == flat_model.policy.mlp_extractor.value_net[2].weight).all()

Did you run both policies with multiple seeds? My first guess is that it's the inherent variability of training. Could you try multiple seeds and check whether the runs have similar performance?
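A minimal sketch of such a comparison (make_env(dict_obs=...) is a hypothetical factory for your environment, and the seed/timestep values are only illustrative):

from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# make_env(dict_obs=...) is a hypothetical factory returning your (non-vectorized) environment
for policy in ("MlpPolicy", "MultiInputPolicy"):
    for seed in (0, 1, 2):
        env = make_env(dict_obs=(policy == "MultiInputPolicy"))
        model = PPO(policy, env, seed=seed, verbose=0)
        model.learn(total_timesteps=100_000)
        mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
        print(f"{policy} seed={seed}: {mean_reward:.2f} +/- {std_reward:.2f}")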

Otherwise, my next guess is that the issue could be in the environment wrappers, if you're using any.
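A quick way to list the wrapper chain (a sketch assuming standard gym.Wrapper nesting around a hypothetical env instance):

# walk the wrapper chain from the outermost wrapper down to the base environment
layer = env  # env is your (hypothetically wrapped) environment instance
while hasattr(layer, "env"):
    print(type(layer).__name__)
    layer = layer.env
print(type(layer).__name__)  # the unwrapped base env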

0 reactions
araffin commented, May 24, 2021

> yes, the environment is using a C++ wrapper for some of the vectorizing processes. Thank you for the helpful advice.

the issue is now solved then 😉

