[RLlib] [Bug] Concatenating obs_space with action_space as input space in RNNSAC build_q_model causes shape mismatch when building the RNN model
Search before asking
- I searched the issues and found no similar issues.
Ray Component
RLlib
What happened + What you expected to happen
Hi, I'm trying to train a multi-agent RNNSAC on my custom environment, but I get a shape mismatch error. I tried to track it down myself: when the Q-model is built, the observation space and action space are concatenated, so the model's input gets a shape of action shape + obs shape, and during training the shape mismatch occurs. I can't quite understand why `build_q_model` concatenates the action and observation spaces.
My custom environment's observation space is (9640,) and its action space is (4031,), both continuous, so with the concatenation in the Q-model build (9640 + 4031 = 13671 inputs) I get a shape error. I'm essentially running the RNNSAC test algorithm on this setup. It's also worth mentioning that everything works perfectly with multi-agent CartPole, but not with my custom env; of course I tested my custom multi-agent env with PPO and PG and it works fine. The error I get is:
```
Traceback (most recent call last):
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 773, in setup
self._init(self.config, self.env_creator)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 873, in _init
raise NotImplementedError
NotImplementedError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/wildsky/Dropbox/AI-AoI-FeLSA/Simulation/marl_test/testSAC.py", line 116, in <module>
trainer = sac.RNNSACTrainer(config=config, env="multi_agent_aoi")
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/sac.py", line 187, in __init__
super().__init__(*args, **kwargs)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 690, in __init__
super().__init__(config, logger_creator, remote_checkpoint_dir,
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/tune/trainable.py", line 122, in __init__
self.setup(copy.deepcopy(self.config))
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 788, in setup
self.workers = self._make_workers(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 1822, in _make_workers
return WorkerSet(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 123, in __init__
self._local_worker = self._make_worker(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 479, in _make_worker
worker = cls(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 587, in __init__
self._build_policy_map(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1550, in _build_policy_map
self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
self[policy_id] = class_(observation_space, action_space,
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 280, in __init__
self._initialize_loss_from_dummy_batch(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 799, in _initialize_loss_from_dummy_batch
self.compute_actions_from_input_dict(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 294, in compute_actions_from_input_dict
return self._compute_action_helper(input_dict, state_batches,
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
return func(self, *a, **k)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 908, in _compute_action_helper
self.action_distribution_fn(
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_policy.py", line 175, in action_distribution_fn
_, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_model.py", line 100, in get_q_values
return self._get_q_value(model_out, actions, self.q_net, state_in,
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_model.py", line 91, in _get_q_value
out, state_out = net(model_out, state_in, seq_lens)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
res = self.forward(restored, state or [], seq_lens)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/recurrent_net.py", line 187, in forward
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 124, in forward
self._features = self._hidden_layers(self._last_flat_in)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/misc.py", line 160, in forward
return self._model(x)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
return F.linear(input, self.weight, self.bias)
File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x9640 and 13671x10)
```
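The shapes in the last line match the spaces above: 9640 + 4031 = 13671, so the Q-net's first linear layer was built for the concatenated [obs, action] input but only received the 9640-dim observation batch. The same error can be reproduced in isolation with a minimal sketch (plain PyTorch, not RLlib code):

```python
import torch
import torch.nn as nn

obs_dim, act_dim, hidden = 9640, 4031, 10

# The Q-net's first layer is sized for obs + action (13671 inputs) ...
q_layer = nn.Linear(obs_dim + act_dim, hidden)

# ... but it is fed only the 9640-dim observation batch.
obs_batch = torch.zeros(32, obs_dim)
q_layer(obs_batch)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x9640 and 13671x10)
```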
I use this code to train:
```python
from ray.tune.registry import register_env
from ray.rllib.env.multi_agent_env import make_multi_agent
from env_rllib import Environment
from ray.rllib.models import ModelCatalog
import ray.rllib.agents.sac as sac
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator
from rnn_model import TorchRNNModel, RNNModel

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

# Multi-agent version of the custom environment.
MultiAgentAOI = make_multi_agent(Environment)

# Custom RNN models (registered but commented out below; the built-in
# LSTM auto-wrapper is used instead).
ModelCatalog.register_custom_model("lstm_model", TorchRNNModel)
ModelCatalog.register_custom_model("lstm_model_tf", RNNModel)

register_env("multi_agent_aoi", lambda x: MultiAgentAOI({"num_agents": 5}))

config = sac.RNNSAC_DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # Run locally.
config["model"] = {
    "max_seq_len": 100,
}
config["env"] = "multi_agent_aoi"
config["policy_model"] = {
    # "custom_model": "lstm_model",
    "fcnet_hiddens": [10],
    "use_lstm": True,
    "lstm_cell_size": 64,
    "lstm_use_prev_action": True,
    "lstm_use_prev_reward": True,
}
config["Q_model"] = {
    # "custom_model": "lstm_model",
    "fcnet_hiddens": [10],
    "use_lstm": True,
    "lstm_cell_size": 64,
    "lstm_use_prev_action": True,
    "lstm_use_prev_reward": True,
}
config["prioritized_replay"] = True
config["burn_in"] = 20
config["zero_init_states"] = True
config["lr"] = 5e-4

num_iterations = 1
for _ in framework_iterator(config, frameworks="torch"):
    trainer = sac.RNNSACTrainer(config=config, env="multi_agent_aoi")
    for i in range(num_iterations):
        results = trainer.train()
        print(results)
```
I don't quite understand this part of the build_q_model method:
```python
def build_q_model(self, obs_space, action_space, num_outputs,
                  q_model_config, name):
    """Builds one of the (twin) Q-nets used by this SAC.

    Override this method in a sub-class of SACTFModel to implement your
    own Q-nets. Alternatively, simply set `custom_model` within the
    top level SAC `Q_model` config key to make this default implementation
    of `build_q_model` use your custom Q-nets.

    Returns:
        TorchModelV2: The TorchModelV2 Q-net sub-model.
    """
    self.concat_obs_and_actions = False
    if self.discrete:
        input_space = obs_space
    else:
        orig_space = getattr(obs_space, "original_space", obs_space)
        if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
            input_space = Box(
                float("-inf"),
                float("inf"),
                shape=(orig_space.shape[0] + action_space.shape[0], ))
            self.concat_obs_and_actions = True
```
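For context, the flip side of that concatenated input space is the Q-value call path: at query time the observation and the actions have to be concatenated again before they reach the Q-net. A simplified sketch of that logic (illustrative only, not RLlib's exact `_get_q_value` implementation):

```python
import torch

def get_q_value_sketch(q_net, model_out, actions, state_in, seq_lens,
                       concat_obs_and_actions=True):
    # Continuous case: the Q-net was built for obs_dim + act_dim inputs,
    # so the actions must be concatenated back onto the observation here.
    if concat_obs_and_actions and actions is not None:
        net_in = torch.cat([model_out, actions], dim=-1)
    else:
        # If actions are None (as in the traceback above), only an
        # obs_dim-wide tensor reaches a layer expecting obs_dim + act_dim.
        net_in = model_out
    # q_net stands for any recurrent Q-net callable taking (input, state, seq_lens).
    return q_net(net_in, state_in, seq_lens)
```

Given a `model_out` of shape (B, 9640) and missing actions, the layer built for 13671 inputs fails exactly as in the traceback.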
Thanks in advance for your guidance.
Versions / Dependencies
Ray v2.0, v1.9
Top GitHub Comments
Hey,
I've run into the same issue.
From the debugging I've done, it seems the calls to `get_q_values` in `action_distribution_fn` don't have actions passed in, so `_get_q_value` doesn't concatenate anything with the observation and the dimension mismatch occurs. With #23814 and passing in `input_dict['actions']` it progresses further, but I'm seeing other, seemingly unrelated issues, and I have no idea whether those are the actions expected at this point in the algorithm. It appears the RNNSAC implementation hasn't been tested with continuous actions; it would be good if someone knowledgeable about how it's supposed to work could take a look. I've seen great performance with the torch implementation of the normal SAC so far.
In RL literature, the Q function takes both states and actions as parameters. When we represent that as a neural network in code, we concatenate the observations and actions together in order to represent that the Q function is a function of these two things.
In the case that your environment has a discrete action space, the Q-net instead takes only the observation as input (and outputs one value per action), so no concatenation happens.
So my best guess here is that your issue has something to do with the observations and actions not being concatenated when they're passed to the `_get_q_value` function of the RNNSAC torch model. This is probably because at some point `self.concat_obs_and_actions` is being set to False, meaning it's trying to pass only the observation, instead of the observation concatenated with the action, to your Q function.
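To make the concatenation point concrete, here is a minimal continuous-action Q-network sketch in plain PyTorch (illustrative only, independent of RLlib's actual model classes), using the observation and action sizes from this report:

```python
import torch
import torch.nn as nn

obs_dim, act_dim = 9640, 4031  # sizes from the reported environment

# Q(s, a): the first layer expects the concatenated [obs, action] vector.
q_net = nn.Sequential(
    nn.Linear(obs_dim + act_dim, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
)

obs = torch.zeros(32, obs_dim)
act = torch.zeros(32, act_dim)

q_values = q_net(torch.cat([obs, act], dim=-1))  # OK: input is 32 x 13671
# q_net(obs) alone would raise the exact shape error shown in the traceback above.
```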