How to use Dict Spaces? `AttributeError: 'Box' object has no attribute 'spaces'`
### 🐛 Bug
Following up on the issue I raised over at Gym, I'm having trouble getting the Dict observation space to "work" properly.
I've posted a minimal reproducible example with my custom environment below, and I've also made a Colab notebook that can be copied from here for reproducibility:
https://colab.research.google.com/drive/1QpSxOA8mVSUja89U9e5bgcMelIiuNArz?usp=sharing
I've created a simple Roulette environment that should allow the agent to place some money on each number of a roulette wheel (a simple implementation to help me understand). I've defined the action and observation spaces as follows:
```python
# Spaces
# Each number on the roulette board can have 0-2 units placed on it
# (MultiDiscrete with nvec=3 gives the values 0, 1 and 2)
self.action_space = gym.spaces.MultiDiscrete([3 for _ in range(37)], dtype=int)
# We're going to keep track of how many times each number shows up
# while we're playing, plus our current bankroll and the max
# table betting limit so the agent knows how much $ in total is allowed
# to be placed on the table. Going to use a Dict space for this.
self.observation_space = gym.spaces.Dict(
    {
        # One counter per number on the wheel, keyed "0" ... "36"
        **{str(n): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int) for n in range(37)},
        "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),
        "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
    },
)
```
As you can see, the agent can bet on any combination of the numbers (MultiDiscrete; with `nvec=3`, each number takes 0-2 units), and after each spin of the wheel we `+= 1` the count of the number that came up, using that number as the key into the observation dict.
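A quick sanity check on the bet range: a `MultiDiscrete` space with `nvec=3` per entry only produces the values 0, 1 and 2. The sketch below imitates that sampling with plain NumPy (it is an illustration, not gym's implementation; `sample_bets` is a hypothetical helper) and works out the largest total stake the action space allows, which is far below the 1000-unit table limit.

```python
import numpy as np

N_NUMBERS = 37      # numbers 0-36 on the wheel
NVEC = 3            # same value as in MultiDiscrete([3] * 37)
TABLE_LIMIT = 1000  # self.max_table_limit in the environment

def sample_bets(rng: np.random.Generator) -> np.ndarray:
    """Draw one action like MultiDiscrete does: each entry in [0, NVEC)."""
    return rng.integers(0, NVEC, size=N_NUMBERS)

rng = np.random.default_rng(0)
samples = np.stack([sample_bets(rng) for _ in range(1000)])

# Every sampled bet is 0, 1 or 2 units; 3 never appears.
print(samples.min(), samples.max())

# Largest total stake the action space can express.
max_total_bet = (NVEC - 1) * N_NUMBERS
print(max_total_bet, max_total_bet <= TABLE_LIMIT)  # 74 True
```

With this action space the `sum_of_bets > self.max_table_limit` check can therefore never trigger.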
I've added the `env = gym.wrappers.FlattenObservation(env)` line because I read that it is needed when using Dict observation spaces; however, I'm currently getting the error shown in the traceback at the end of this issue.
When I run the suggested `check_env(env)` on the FLATTENED environment, I get this error:
```python
env = gym.wrappers.FlattenObservation(Roulette_Environment())
# env = Roulette_Environment()
from stable_baselines3.common.env_checker import check_env
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-6-efbf74f2cf71> in <module>
3 from stable_baselines3.common.env_checker import check_env
4 # It will check your custom environment and output additional warnings if needed
----> 5 check_env(env)
2 frames
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/env_checker.py in check_env(env, warn, skip_render_check)
300
301 # ============ Check the returned values ===============
--> 302 _check_returned_values(env, observation_space, action_space)
303
304 # ==== Check the render method and the declared render modes ====
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/env_checker.py in _check_returned_values(env, observation_space, action_space)
140 """
141 # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists
--> 142 obs = env.reset()
143
144 if isinstance(observation_space, spaces.Dict):
/usr/local/lib/python3.7/dist-packages/gym/core.py in reset(self, **kwargs)
377 def reset(self, **kwargs):
378 """Resets the environment, returning a modified observation using :meth:`self.observation`."""
--> 379 obs, info = self.env.reset(**kwargs)
380 return self.observation(obs), info
381
ValueError: too many values to unpack (expected 2)
```
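This `ValueError` is raised inside the wrapper, not the environment. The traceback shows gym's `ObservationWrapper.reset` doing `obs, info = self.env.reset(**kwargs)`, i.e. the newer Gym API (around 0.26) in which `reset()` returns an `(obs, info)` tuple, while `Roulette_Environment.reset()` returns only the observation dict. Unpacking a dict iterates over its keys, so a 39-key dict cannot be unpacked into two names. The pure-Python sketch below reproduces that; `wrapper_reset` is a hypothetical stand-in for the failing line, not gym code. Note that the SB3 `check_env` in the same traceback still calls `obs = env.reset()` (old API), so the environment, the wrapper and SB3 all need to target the same Gym API.

```python
def old_api_reset():
    """Mimics Roulette_Environment.reset(): returns only the observation dict."""
    obs = {str(n): 0 for n in range(37)}
    obs["current_bankroll"] = 2000
    obs["max_table_limit"] = 1000
    return obs

def new_api_reset():
    """Same observation, returned as an (obs, info) tuple (newer Gym API)."""
    return old_api_reset(), {}

def wrapper_reset(reset_fn):
    """Hypothetical stand-in for `obs, info = self.env.reset(**kwargs)`."""
    obs, info = reset_fn()
    return obs, info

try:
    wrapper_reset(old_api_reset)
except ValueError as err:
    print(err)  # a "too many values to unpack" message

obs, info = wrapper_reset(new_api_reset)
print(len(obs), info)  # 39 {}
```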
Then when I try it with the UNFLATTENED environment, I get this:
```python
#env = gym.wrappers.FlattenObservation(Roulette_Environment())
env = Roulette_Environment()
from stable_baselines3.common.env_checker import check_env
# It will check your custom environment and output additional warnings if needed
check_env(env)
```
```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/env_checker.py in _check_returned_values(env, observation_space, action_space)
154 try:
--> 155 _check_obs(obs[key], observation_space.spaces[key], "reset")
156 except AssertionError as e:
3 frames
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/env_checker.py in _check_obs(obs, observation_space, method_name)
109 elif _is_numpy_array_space(observation_space):
--> 110 assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array"
111
AssertionError: The observation returned by `reset()` method must be a numpy array
The above exception was the direct cause of the following exception:
AssertionError Traceback (most recent call last)
<ipython-input-7-3cc4f8f118c0> in <module>
3 from stable_baselines3.common.env_checker import check_env
4 # It will check your custom environment and output additional warnings if needed
----> 5 check_env(env)
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/env_checker.py in check_env(env, warn, skip_render_check)
300
301 # ============ Check the returned values ===============
--> 302 _check_returned_values(env, observation_space, action_space)
303
304 # ==== Check the render method and the declared render modes ====
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/env_checker.py in _check_returned_values(env, observation_space, action_space)
155 _check_obs(obs[key], observation_space.spaces[key], "reset")
156 except AssertionError as e:
--> 157 raise AssertionError(f"Error while checking key={key}: " + str(e)) from e
158 else:
159 _check_obs(obs, observation_space, "reset")
AssertionError: Error while checking key=0: The observation returned by `reset()` method must be a numpy array
```
So it seems to suggest that the `reset` method needs to return an array rather than a `gym.spaces.Dict`? I'm just not sure.
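On the question above: returning a dict is fine for a `Dict` space; what the checker rejects is each *value* inside it. For a `Dict` space, `check_env` checks every key separately, and for a `Box` subspace it asserts that the value is a NumPy array (see `_check_obs` in the traceback). `reset()` sets each count to the plain integer `0`, so the very first key fails. The sketch below re-implements only that `isinstance` assertion for illustration; `check_dict_obs` is a hypothetical name, and the real checker also verifies shape, dtype and bounds.

```python
import numpy as np

def check_dict_obs(obs: dict, keys) -> None:
    """Simplified per-key check: every value must be a numpy array."""
    for key in keys:
        try:
            assert isinstance(obs[key], np.ndarray), (
                "The observation returned by `reset()` method must be a numpy array"
            )
        except AssertionError as e:
            raise AssertionError(f"Error while checking key={key}: " + str(e)) from e

keys = [str(n) for n in range(37)] + ["current_bankroll", "max_table_limit"]

# What reset() currently produces: plain Python ints.
bad_obs = {key: 0 for key in keys}
# What a Box(shape=(1,), dtype=int) subspace expects: shape-(1,) integer arrays.
good_obs = {key: np.array([0], dtype=int) for key in keys}

try:
    check_dict_obs(bad_obs, keys)
except AssertionError as e:
    print(e)  # Error while checking key=0: The observation returned by `reset()` method must be a numpy array

check_dict_obs(good_obs, keys)  # passes silently
```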
I'd like to use the Dict space for both the action and observation spaces, but I need some direction on getting it working in my simple example. To see the error, run the Colab notebook; after the packages are installed the first time, hit Restart and Run All.
### Code example
```python
import os, sys
if not os.path.isdir('/usr/local/lib/python3.7/dist-packages/stable_baselines3'):
    !pip3 install stable_baselines3
    !pip3 install -U gym
    print("\n\n\n Stable Baselines3 has been installed, Restart and Run All now. DO NOT factory reset, or you'll have to start over\n")
    sys.exit(0)

from random import randint
from numpy import inf, float32, array, int32, int64, concatenate
import gym
from gym.wrappers import FlattenObservation
from stable_baselines3 import A2C, DQN, PPO, DDPG, HER, SAC, TD3
print(gym.__version__)
import stable_baselines3
print(stable_baselines3.__version__)

"""Roulette environment class"""
class Roulette_Environment(gym.Env):
    metadata = {'render.modes': ['human', 'text']}

    """Initialize the environment"""
    def __init__(self):
        super(Roulette_Environment, self).__init__()
        # Some global variables
        self.max_table_limit = 1000
        self.initial_bankroll = 2000
        # Spaces
        # Each number on the roulette board can have 0-2 units placed on it
        # (an earlier version used this Box, allowing 0-1000 units):
        # self.action_space = gym.spaces.Box(low=0, high=1000, shape=(37,), dtype=int)
        self.action_space = gym.spaces.MultiDiscrete([3 for _ in range(37)], dtype=int)
        # We're going to keep track of how many times each number shows up
        # while we're playing, plus our current bankroll and the max
        # table betting limit so the agent knows how much $ in total is allowed
        # to be placed on the table. Going to use a Dict space for this.
        self.observation_space = gym.spaces.Dict(
            {
                # One counter per number on the wheel, keyed "0" ... "36"
                **{str(n): gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int) for n in range(37)},
                "current_bankroll": gym.spaces.Box(low=-inf, high=inf, shape=(1,), dtype=int),
                "max_table_limit": gym.spaces.Box(low=0, high=inf, shape=(1,), dtype=int),
            },
        )
        # # TEST - Doesn't work either...
        # for num in range(37):
        #     self.observation_space[str(num)] = gym.spaces.Discrete(1000)

    """Reset the Environment"""
    def reset(self):
        self.current_bankroll = self.initial_bankroll
        self.done = False
        # Take a sample from the observation_space as a template whose values we overwrite below
        self.current_state = self.observation_space.sample()
        # Reset each number being tracked throughout gameplay to 0
        for i in range(0, 37):
            self.current_state[str(i)] = 0
        # Reset our globals
        self.current_state['current_bankroll'] = self.initial_bankroll
        self.current_state['max_table_limit'] = self.max_table_limit
        # self.current_state = dict(self.current_state).values()
        # self.current_state = list(self.current_state.values())
        return self.current_state

    """Step Through the Environment"""
    def step(self, action):
        # # Convert actions to ints cuz they show up as floats,
        # # even when defined as ints in the environment.
        # # https://github.com/openai/gym/issues/3107
        # for i in range(len(action)):
        #     action[i] = int(action[i])
        self.current_action = action
        # Total of all bets placed this spin
        sum_of_bets = sum([bet for bet in self.current_action])
        # Spin the wheel
        self.current_number = randint(a=0, b=36)
        # Update the current state
        self.current_state['current_bankroll'] = self.current_bankroll
        self.current_state[str(self.current_number)] += 1
        self.current_state = array(dict(self.current_state).values())
        # Make sure we're allowed to place the proposed bet
        if sum_of_bets > self.max_table_limit or sum_of_bets > self.current_bankroll:
            return self.current_state, 0, self.done, {}
        # Calculate payout/reward
        self.reward = 36 * self.current_action[self.current_number] - sum_of_bets
        self.current_bankroll += self.reward
        # If we've doubled our money, or lost our money
        if self.current_bankroll >= self.initial_bankroll * 2 or self.current_bankroll <= 0:
            self.done = True
        # self.current_state = FlattenObservation(self.current_state)
        # print(self.current_state)
        return self.current_state, self.reward, self.done, {}

    """Render the Environment"""
    def render(self, mode='text'):
        # Text rendering
        if mode == "text":
            print(f'Bets Placed: {self.current_action}')
            print(f'Number rolled: {self.current_number}')
            print(f'Reward: {self.reward}')
            print(f'New Bankroll: {self.current_bankroll}')

# #env = gym.wrappers.FlattenObservation(Roulette_Environment())
# env = Roulette_Environment()
# from stable_baselines3.common.env_checker import check_env
# # It will check your custom environment and output additional warnings if needed
# check_env(env)

# env = Roulette_Environment()
env = gym.wrappers.FlattenObservation(Roulette_Environment())
model = A2C('MultiInputPolicy', env, verbose=1)
model.learn(total_timesteps=10000)

# obs = Roulette_Environment()
# print(FlattenObservation(obs))
# obs = obs.reset()
# print(obs)
obs = env.reset()
for i in range(1000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
```
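Separately from the reported errors, note the line `self.current_state = array(dict(self.current_state).values())` in `step()`. NumPy does not treat a `dict_values` view as a sequence, so `array(...)` wraps the whole view in a 0-d object array instead of building a numeric vector, and the observation stops being a dict that the next `step()` can index by key. The short sketch below (with a made-up three-key state) shows the difference; if a flat vector is really wanted, convert the view with `list(...)` first, or keep the dict and let the Dict space or a flattening wrapper handle it.

```python
import numpy as np

state = {"0": np.array([2]), "1": np.array([0]), "current_bankroll": np.array([1990])}

# What step() currently does: the dict_values view is stored as one opaque object.
wrapped = np.array(dict(state).values())
print(wrapped.shape, wrapped.dtype)  # () object

# Converting the view to a list first gives a numeric array instead.
flat = np.array(list(state.values())).ravel()
print(flat.tolist())  # [2, 0, 1990]
```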
### Relevant log output / Error message
```
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-9-3e7a3ee77e74> in <module>
2 env = gym.wrappers.FlattenObservation(Roulette_Environment())
3
----> 4 model = A2C('MultiInputPolicy', env, verbose=1)
5 model.learn(total_timesteps=10000)
6
4 frames
/usr/local/lib/python3.7/dist-packages/stable_baselines3/a2c/a2c.py in __init__(self, policy, env, learning_rate, n_steps, gamma, gae_lambda, ent_coef, vf_coef, max_grad_norm, rms_prop_eps, use_rms_prop, use_sde, sde_sample_freq, normalize_advantage, tensorboard_log, create_eval_env, policy_kwargs, verbose, seed, device, _init_setup_model)
124
125 if _init_setup_model:
--> 126 self._setup_model()
127
128 def train(self) -> None:
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/on_policy_algorithm.py in _setup_model(self)
126 self.lr_schedule,
127 use_sde=self.use_sde,
--> 128 **self.policy_kwargs # pytype:disable=not-instantiable
129 )
130 self.policy = self.policy.to(self.device)
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/policies.py in __init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, sde_net_arch, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, normalize_images, optimizer_class, optimizer_kwargs)
816 normalize_images,
817 optimizer_class,
--> 818 optimizer_kwargs,
819 )
820
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/policies.py in __init__(self, observation_space, action_space, lr_schedule, net_arch, activation_fn, ortho_init, use_sde, log_std_init, full_std, sde_net_arch, use_expln, squash_output, features_extractor_class, features_extractor_kwargs, normalize_images, optimizer_class, optimizer_kwargs)
457 self.ortho_init = ortho_init
458
--> 459 self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
460 self.features_dim = self.features_extractor.features_dim
461
/usr/local/lib/python3.7/dist-packages/stable_baselines3/common/torch_layers.py in __init__(self, observation_space, cnn_output_dim)
256
257 total_concat_size = 0
--> 258 for key, subspace in observation_space.spaces.items():
259 if is_image_space(subspace):
260 extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim)
AttributeError: 'Box' object has no attribute 'spaces'
```
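The `AttributeError` is a mismatch between the wrapper and the policy. `FlattenObservation` turns the `Dict` observation space into one flat `Box`, while the feature extractor that `'MultiInputPolicy'` builds iterates over `observation_space.spaces` (line 258 of `torch_layers.py` above), an attribute only a `Dict` space has. So pick one: keep the `Dict` space unwrapped with `'MultiInputPolicy'`, or flatten it and use a single-input policy such as `'MlpPolicy'`. The NumPy sketch below imitates what flattening does to one observation; `flatten_dict_obs` is a hypothetical helper, and the sorted key order assumes, as gym does for a plain `dict` passed to `spaces.Dict`, that keys are sorted as strings.

```python
import numpy as np

def flatten_dict_obs(obs: dict) -> np.ndarray:
    """Concatenate every value into one 1-D vector, in sorted key order."""
    return np.concatenate([np.asarray(obs[key]).ravel() for key in sorted(obs)])

obs = {str(n): np.array([0]) for n in range(37)}
obs["17"] += 1                          # number 17 came up once
obs["current_bankroll"] = np.array([2000])
obs["max_table_limit"] = np.array([1000])

flat = flatten_dict_obs(obs)
print(flat.shape)                       # (39,)

# Keys sort as strings, so "10" comes before "2".
order = sorted(obs)
print(order[:4])                        # ['0', '1', '10', '11']
print(flat[order.index("17")])          # 1
```

The result is a single vector with no per-key structure left, which is exactly what a single-input policy expects and what a multi-input policy cannot use.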
### System Info
_No response_
### Checklist
- [X] I have checked that there is no similar [issue](https://github.com/DLR-RM/stable-baselines3/issues) in the repo
- [X] I have read the [documentation](https://stable-baselines3.readthedocs.io/en/master/)
- [X] I have provided a minimal working example to reproduce the bug
- [X] I have checked my env using the env checker
- [X] I've used the [markdown code blocks](https://help.github.com/en/articles/creating-and-highlighting-code-blocks) for both code and stack traces.
### Top GitHub Comments
The `check_env` function is quite explicit about the error: the `reset` method returns a dictionary whose values are integers, while they should be arrays. You must therefore modify the code so that the returned observation is a dict whose values are arrays, not integers.
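A minimal sketch of that change, assuming the same keys and `Box(shape=(1,), dtype=int)` subspaces as in the issue: every value is created as a shape-`(1,)` array with the declared dtype, and the per-spin update modifies the arrays in place so they stay arrays. `initial_state` and `record_spin` are hypothetical helpers; inside the environment their bodies would live in `reset()` and `step()`.

```python
import numpy as np

N_NUMBERS = 37

def initial_state(bankroll: int = 2000, table_limit: int = 1000) -> dict:
    """Build the observation dict with array values instead of plain ints."""
    state = {str(n): np.zeros((1,), dtype=int) for n in range(N_NUMBERS)}
    state["current_bankroll"] = np.array([bankroll], dtype=int)
    state["max_table_limit"] = np.array([table_limit], dtype=int)
    return state

def record_spin(state: dict, number: int, bankroll: int) -> dict:
    """Update after a spin; both writes keep the values as arrays."""
    state[str(number)] += 1                  # in-place, still shape (1,)
    state["current_bankroll"][0] = bankroll  # write into the existing array
    return state

state = initial_state()
state = record_spin(state, number=5, bankroll=1990)
print(type(state["5"]).__name__, state["5"], state["current_bankroll"])  # ndarray [1] [1990]
```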
[Offtopic]
True and true. But "custom environment" does not mean "personalized assistance service". Our goal is to make sure that our discussion benefits as many people as possible. This is why we need to work on general code. For example: is a 37-key dictionary necessary to reproduce your error? Probably not, so we prefer a dict with only 1 key, because it is more readable. Does the reward have to be calculated with this level of complexity to reproduce the error? Probably not, so we prefer a constant reward of 0. Etc. But it's not my job to do this; it's your job to prepare your question.
People of very different levels post issues (look around at other issues). I meant to be indulgent, not condescending; sorry you felt that way.
Here, the goal is not to fix your code, but to address the general problem that leads to the error you encountered. That's why I'm giving a general (not vague) answer, so that it can benefit everyone.
For future readers:
- The env checker only checks that the returned values match what you declared, in both shape and data type (for instance, if you return a float32 instead of a float64, the env checker will report an error even though the code might run), using the `observation_space.contains()` method. So you should definitely investigate the difference in the data type returned; if there were no difference, both versions would behave the same way. The good news is that you now have both an example that works and one that throws an error, so by progressively changing the one that works, you should be able to isolate the issue.
- The whole point of a minimal example that reproduces the error is not for us to debug it for you, but to help you isolate what might be wrong and to check that there is no bug in SB3 that needs fixing.
- As explicitly mentioned in the issue template and in the README: "Important Note: We do not do technical support, nor consulting and don't answer personal questions per email. Please post your question on the RL Discord, Reddit or Stack Overflow in that case."
- In other words, we help you understand how to use SB3 and isolate potential issues that come from SB3, but we don't debug issues that come from outside SB3 (which appears to be the case here).
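To make the `contains()` point concrete, here is a simplified NumPy re-implementation of the kind of test a `Box.contains()` call performs: matching shape, a dtype that can be safely cast to the declared one, and values within bounds. `box_contains` is a hypothetical helper written for illustration, not gym's code; the exact rules differ between gym versions.

```python
import numpy as np

def box_contains(x, low, high, shape, dtype) -> bool:
    """Simplified Box membership test: shape, safe dtype cast, then bounds."""
    if not isinstance(x, np.ndarray):
        x = np.asarray(x)
    return bool(
        x.shape == shape
        and np.can_cast(x.dtype, dtype)
        and np.all(x >= low)
        and np.all(x <= high)
    )

spec = dict(low=0, high=np.inf, shape=(1,), dtype=np.int64)

print(box_contains(np.array([3], dtype=np.int64), **spec))    # True
print(box_contains(np.array([3.0]), **spec))                  # False: float64 cannot be safely cast to int64
print(box_contains(np.array([-1], dtype=np.int64), **spec))   # False: below low
print(box_contains(np.array([[3]], dtype=np.int64), **spec))  # False: wrong shape
print(box_contains(0, **spec))                                # False: a bare int has shape ()
```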