[BUG] GymWrapper does not work with nested observation gym.spaces.Dict
Describe the bug
Hi All,
First of all: thanks for the great work here!
I think I have encountered a bug in the `GymWrapper` in `torchrl.envs.libs.gym`. When I use a `gym.Env` with an observation space containing nested `gym.spaces.Dict` spaces, a `KeyError` is thrown, since `GymLikeEnv.read_obs()` only adds "next_" to the first level of the Dict but not to nested sub-Dicts:

```python
observations = {"next_" + key: value for key, value in observations.items()}
```

Since `_gym_to_torchrl_spec_transform()` in `torchrl.envs.libs.gym` appends "next_" in a recursive call over all sub-Dicts, the nested keys are missing the necessary "next_" prefix. Nested Dict observation spaces are often used (https://www.gymlibrary.dev/api/spaces/#dict), so I guess this is required to work properly.
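To illustrate the mismatch, here is a minimal sketch (key names taken from the example below) of what the single-level comprehension produces versus what the recursively built spec expects:

```python
import numpy as np

# A nested observation, as sampled from the space in the example below
obs = {
    "sensor_3": np.zeros(4, dtype=np.float32),
    "sensor_4": {"sensor_41": np.zeros(1, dtype=np.float32)},
}

# What GymLikeEnv.read_obs() currently produces: only the top level is prefixed
flat = {"next_" + key: value for key, value in obs.items()}
print(list(flat))                   # ['next_sensor_3', 'next_sensor_4']
print(list(flat["next_sensor_4"]))  # ['sensor_41'] <- the spec expects 'next_sensor_41'
```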
To Reproduce
```python
#!/usr/bin/env python
from torchrl.envs.libs.gym import GymWrapper
from gym import spaces, Env
import numpy as np


class CustomGym(Env):
    def __init__(self):
        self.action_space = spaces.Discrete(5)
        self.observation_space = spaces.Dict(
            {
                'sensor_1': spaces.Box(low=0, high=255, shape=(5, 5, 3), dtype=np.uint8),
                'sensor_2': spaces.Box(low=0, high=255, shape=(5, 5, 3), dtype=np.uint8),
                'sensor_3': spaces.Box(np.array([-2, -1, -5, 0]), np.array([2, 1, 30, 1]), dtype=np.float32),
                'sensor_4': spaces.Dict(
                    {
                        'sensor_41': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32),
                        'sensor_42': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32),
                        'sensor_43': spaces.Box(low=0, high=100, shape=(1,), dtype=np.float32),
                    }
                ),
            }
        )

    def reset(self):
        return self.observation_space.sample()


if __name__ == '__main__':
    env = CustomGym()
    env = GymWrapper(env)
```
Reason and Possible fixes
The issue can be fixed by adding a recursive function to `GymLikeEnv.read_obs()` that also renames nested observation Dicts correctly, prefixing keys at every level with "next_":
```python
def read_obs(
    self, observations: Union[Dict[str, Any], torch.Tensor, np.ndarray]
) -> Dict[str, Any]:
    """Reads an observation from the environment and returns an observation compatible with the output TensorDict.

    Args:
        observations (observation under a format dictated by the inner env): observation to be read.
    """
    if isinstance(observations, dict):

        def rename(obs):
            # Recursively prefix keys at every nesting level with "next_"
            return {
                "next_" + key: rename(value) if isinstance(value, dict) else value
                for key, value in obs.items()
            }

        observations = rename(observations)
    if not isinstance(observations, (TensorDict, dict)):
        key = list(self.observation_spec.keys())[0]
        observations = {key: observations}
    observations = self.observation_spec.encode(observations)
    return observations
```
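As a standalone illustration of the recursive rename (same logic, outside the class):

```python
def rename(obs):
    # Prefix every key with "next_", recursing into nested dicts
    return {
        "next_" + key: rename(value) if isinstance(value, dict) else value
        for key, value in obs.items()
    }

obs = {"sensor_3": 0.0, "sensor_4": {"sensor_41": 1.0}}
print(rename(obs))
# {'next_sensor_3': 0.0, 'next_sensor_4': {'next_sensor_41': 1.0}}
```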
The style checker does not allow lambda functions; otherwise the fix could be as simple as:
```python
rename = lambda obs: {
    "next_" + key: rename(value) if isinstance(value, dict) else value
    for key, value in obs.items()
}
```
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
Top GitHub Comments
Yes sure, I’ll take care of it 😃 Thanks for the feedback!
Hey! I see your point. We’re thinking about redesigning this API. I will open a PR with that shortly, but I’d be glad to get your thoughts about it.
First, I think the `"next_obs"` prefix is messy and makes it hard to get the tensordict of the next step. Second, it does not scale well to other problems (e.g. MCTS, or planners in general, where we explore many different possible actions for a single state). Finally, it requires users to pay attention to naming the obs in the specs with the `"next"` prefix, which they might as well forget or find cumbersome.

Here's what I would see: before, `env.step` returns a tensordict with flat `"next_"`-prefixed keys; we would change that so the next-step entries live in a nested sub-tensordict. That way, `step_mdp` just needs to do `tensordict = tensordict["step"].clone(recurse=False)` (we clone it, otherwise the original tensordict will keep track of the whole trajectory!). If you like the previous API, you can just do `tensordict.flatten_keys("_")`. So in your case you'd have something like this:
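A minimal sketch of the key layout this implies for the environment above (illustrative only; the exact structure and the `"step"` key name are assumptions based on this comment):

```python
# Before: next-step values carry a flat "next_" prefix
before = {
    "sensor_3": ...,       # current observation entries
    "action": ...,
    "reward": ...,
    "done": ...,
    "next_sensor_3": ...,  # next-step entries, prefixed
}

# After: next-step entries nested in a sub-tensordict, no prefix needed
after = {
    "sensor_3": ...,
    "action": ...,
    "reward": ...,
    "done": ...,
    "step": {
        "sensor_3": ...,   # same key as at the root level
    },
}

# step_mdp then reduces to taking the sub-tensordict:
#   tensordict = tensordict["step"].clone(recurse=False)
# and flatten_keys("_") recovers flat, prefixed keys if preferred.
```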
Thoughts?
cc @shagunsodhani (by the way, it's funny that we were just talking about that feature a couple of hours ago and @raphajaner came up with a very similar idea!)