It is recommended to give an example of off policy using the feature extractor
🐛 Bug
I want to customize the feature extractor. Following the example in the documentation, I get the error below. I have seen issue #425, "too many errors when customizing policy, a full example for off policy algorithms should be added in user guide", which mentions that the off-policy networks should also use the feature extractor. It would help to have an example of an off-policy algorithm using a custom feature extractor. Thank you!

import gym
import torch as th
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        # We do not know features_dim here before going over all the items,
        # so put something dummy for now. PyTorch requires calling
        # nn.Module.__init__ before adding modules
        super().__init__(observation_space, features_dim=1)

        extractors = {}
        total_concat_size = 0
        # We need to know size of the output of this extractor,
        # so go over all the spaces and compute output feature sizes
        for key, subspace in observation_space.spaces.items():
            if key == "image":
                # We will just downsample one channel of the image by 4x4 and flatten.
                # Assume the image is single-channel (subspace.shape[0] == 1)
                extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                total_concat_size += (subspace.shape[1] // 4) * (subspace.shape[2] // 4)
            elif key == "vector":
                # Run through a simple MLP
                extractors[key] = nn.Linear(subspace.shape[0], 16)
                total_concat_size += 16

        self.extractors = nn.ModuleDict(extractors)

        # Update the features dim manually
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        encoded_tensor_list = []
        # self.extractors contain nn.Modules that do all the processing.
        for key, extractor in self.extractors.items():
            encoded_tensor_list.append(extractor(observations[key]))
        # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
        return th.cat(encoded_tensor_list, dim=1)
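The size bookkeeping in `__init__` has to match what `nn.MaxPool2d(4)` followed by `nn.Flatten()` really produces, otherwise the concatenated tensor will not have `features_dim` columns. Written as `h // 4 * w // 4`, the expression evaluates left to right as `((h // 4) * w) // 4`, which is not always the pooled size. A minimal stdlib-only sketch of that arithmetic, assuming a single-channel image, kernel and stride 4, and no padding (what `MaxPool2d` uses when only `kernel_size` is given); the helper names are hypothetical:

```python
# Sketch: flattened size of a single-channel image after MaxPool2d(4) + Flatten,
# computed without PyTorch. Assumes kernel = stride = 4, no padding, floor mode.


def pooled_flat_size(height: int, width: int, kernel: int = 4) -> int:
    """Number of features after MaxPool2d(kernel) + Flatten on a 1-channel image."""
    return (height // kernel) * (width // kernel)


def unparenthesized_size(height: int, width: int) -> int:
    """`height // 4 * width // 4`, evaluated left to right as Python does."""
    return height // 4 * width // 4


for h, w in [(64, 64), (10, 10), (84, 84)]:
    print(h, w, pooled_flat_size(h, w), unparenthesized_size(h, w))
```

For 64×64 and 84×84 both forms agree (256 and 441), but for 10×10 pooling yields 4 features while the left-to-right expression reports 5, so the parenthesized form is the safe one.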
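# CustomCombinedExtractor.__init__ only accepts observation_space and computes
# its own output size, so features_extractor_kwargs must not pass features_dim
# (doing so raises the TypeError in the traceback below).
policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    share_features_extractor=False,
)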
#policy_kwargs = dict(activation_fn=th.nn.ReLU,
# net_arch=[dict(pi=[32, 32], vf=[32, 32])])
from typing import Any, Optional

import numpy as np


# Method of the user's agent class. MODELS, MODEL_KWARGS, NOISE and config
# are defined elsewhere in the user's project.
def get_model(
    self,
    model_name: str,
    policy: str = "MultiInputPolicy",  # Dict observation space -> MultiInputPolicy
    policy_kwargs: dict = policy_kwargs,
    model_kwargs: Optional[dict] = None,
    verbose: int = 1,
) -> Any:
    if model_name not in MODELS:
        raise NotImplementedError(f"Unknown model: {model_name}")
    if model_kwargs is None:
        model_kwargs = MODEL_KWARGS[model_name]
    # Copy so the caller's dict (or the MODEL_KWARGS entry) is not mutated below
    model_kwargs = dict(model_kwargs)
    if "action_noise" in model_kwargs:
        n_actions = self.env.action_space.shape[-1]
        model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](
            mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
        )
    model = MODELS[model_name](
        policy=policy,
        env=self.env,
        tensorboard_log=f"{config.TENSORBOARD_LOG_DIR}/{model_name}",
        verbose=verbose,
        policy_kwargs=policy_kwargs,
        **model_kwargs,
    )
    return model
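`get_model` assigns the instantiated noise object back into `model_kwargs["action_noise"]`. Unless `model_kwargs` is copied first, that assignment overwrites the caller's dict (for example `SAC_PARAMS`, or the shared `MODEL_KWARGS[model_name]` entry), and a second call then looks up `NOISE[<noise object>]`. A self-contained sketch of that effect; `HypotheticalNoise`, `NOISE` and `resolve_kwargs` are illustrative stand-ins, not the project's or SB3's actual objects:

```python
from typing import Any, Callable, Dict


class HypotheticalNoise:
    """Stand-in for an action-noise class (e.g. a Gaussian noise process)."""

    def __init__(self, mean: float, sigma: float) -> None:
        self.mean = mean
        self.sigma = sigma


# Hypothetical registry mapping a noise name to its class
NOISE: Dict[str, Callable[..., Any]] = {"normal": HypotheticalNoise}


def resolve_kwargs(registry: Dict[str, dict], model_name: str, copy: bool) -> dict:
    model_kwargs = registry[model_name]
    if copy:
        # Shallow copy: the assignment below no longer touches the registry entry
        model_kwargs = dict(model_kwargs)
    if "action_noise" in model_kwargs:
        # Swap the noise *name* for an instantiated noise object
        model_kwargs["action_noise"] = NOISE[model_kwargs["action_noise"]](mean=0.0, sigma=0.1)
    return model_kwargs


# With a copy, repeated calls keep working and the registry still holds the name.
safe = {"sac": {"action_noise": "normal"}}
resolve_kwargs(safe, "sac", copy=True)
resolve_kwargs(safe, "sac", copy=True)
print(safe["sac"]["action_noise"])  # prints: normal

# Without a copy, the first call replaces the name with an object,
# so the second lookup NOISE[<object>] raises KeyError.
unsafe = {"sac": {"action_noise": "normal"}}
resolve_kwargs(unsafe, "sac", copy=False)
try:
    resolve_kwargs(unsafe, "sac", copy=False)
except KeyError:
    print("second call failed with KeyError")
```

A shallow copy is enough here because only the top-level `action_noise` key is reassigned.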
Traceback (most recent call last):
  File "C:/Users/Administrator/PycharmProjects/demo/utils/models.py", line 419, in <module>
    model_sac = agent.get_model("sac", model_kwargs=SAC_PARAMS)
  File "C:/Users/Administrator/PycharmProjects/demo/utils/models.py", line 328, in get_model
    model = MODELS[model_name](
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\sac.py", line 144, in __init__
    self._setup_model()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\sac.py", line 147, in _setup_model
    super(SAC, self)._setup_model()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 216, in _setup_model
    self.policy = self.policy_class(  # pytype:disable=not-instantiable
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 498, in __init__
    super(MultiInputPolicy, self).__init__(
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 292, in __init__
    self._build(lr_schedule)
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 295, in _build
    self.actor = self.make_actor()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\sac\policies.py", line 348, in make_actor
    actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\common\policies.py", line 112, in _update_features_extractor
    features_extractor = self.make_features_extractor()
  File "C:\ProgramData\Anaconda3\lib\site-packages\stable_baselines3\common\policies.py", line 118, in make_features_extractor
    return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
TypeError: __init__() got an unexpected keyword argument 'features_dim'
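The last frame shows the extractor being built as `self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)`, so every key in `features_extractor_kwargs` becomes a keyword argument of the extractor's `__init__`. A stdlib-only sketch of that call pattern (the two classes and `make_features_extractor` below are hypothetical stand-ins, not SB3 code) shows the two ways out: stop passing `features_dim`, or let `__init__` accept it:

```python
from typing import Any, Dict, Type


class ExtractorWithoutFeaturesDim:
    """Mimics CustomCombinedExtractor: __init__ only takes observation_space."""

    def __init__(self, observation_space: Any) -> None:
        self.observation_space = observation_space


class ExtractorWithFeaturesDim:
    """Alternative whose __init__ also accepts (and stores) features_dim."""

    def __init__(self, observation_space: Any, features_dim: int = 128) -> None:
        self.observation_space = observation_space
        self.features_dim = features_dim


def make_features_extractor(cls: Type, observation_space: Any, kwargs: Dict[str, Any]) -> Any:
    # Same call shape as the last traceback frame:
    # features_extractor_class(observation_space, **features_extractor_kwargs)
    return cls(observation_space, **kwargs)


space = {"image": None, "vector": None}  # placeholder for a gym.spaces.Dict

try:
    make_features_extractor(ExtractorWithoutFeaturesDim, space, {"features_dim": 128})
except TypeError as err:
    print("TypeError:", err)

# Option 1: do not forward features_dim
ext1 = make_features_extractor(ExtractorWithoutFeaturesDim, space, {})
# Option 2: accept features_dim in __init__
ext2 = make_features_extractor(ExtractorWithFeaturesDim, space, {"features_dim": 128})
print(ext2.features_dim)  # prints: 128
```

Since `CustomCombinedExtractor` computes its own output size, dropping the keyword is probably the simpler option; accepting a `features_dim` that is then ignored would only be misleading.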
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
Issue Analytics
- Created a year ago
- Comments: 8 (2 by maintainers)
Top GitHub Comments
I would really like to help you, but you should at least take into consideration the remarks I give you. I need:
and this code is not functional because the imports are missing:
Your code is neither minimal nor functional. So I am not able to reproduce your error.
From what I can see, it appears to be a shape-related error. You may have made a mistake in the network specification.
Can you try to provide a minimal and functional code example to reproduce the error? (Remove all your print statements, use a single agent, …) Please also use the markdown code blocks for code. It will be easier for us to help you.