
[RLlib] [Bug] Concatenating obs_space with action_space as input space in RNNSAC build_q_model method causes shape mismatch when building RNN model

See original GitHub issue

Search before asking

  • I searched the issues and found no similar issues.

Ray Component

RLlib

What happened + What you expected to happen

Hi, I'm trying to train a multi-agent RNNSAC on my custom environment, but I get a shape mismatch error. I tried to resolve this on my own and found that when the Q-model is built, the observation space and action space get concatenated, so the model's input shape becomes obs shape + action shape, and the mismatch then occurs during training. I can't quite understand why build_q_model is concatenating the action and obs.

My custom env's observation space is (9640,) and action space is (4031,), both continuous, so with the concatenation in Q-model building I get a shape error. I'm basically using the RNNSAC test script to run the model. It's also worth mentioning that everything works perfectly with the multi-agent CartPole example but not with the custom env; of course I tested my custom multi-agent env with PPO and PG and it works fine. The error I get is:

Traceback (most recent call last):
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 773, in setup
    self._init(self.config, self.env_creator)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 873, in _init
    raise NotImplementedError
NotImplementedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/wildsky/Dropbox/AI-AoI-FeLSA/Simulation/marl_test/testSAC.py", line 116, in <module>
    trainer = sac.RNNSACTrainer(config=config, env="multi_agent_aoi")
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/sac.py", line 187, in __init__
    super().__init__(*args, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 690, in __init__
    super().__init__(config, logger_creator, remote_checkpoint_dir,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/tune/trainable.py", line 122, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 788, in setup
    self.workers = self._make_workers(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 1822, in _make_workers
    return WorkerSet(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 123, in __init__
    self._local_worker = self._make_worker(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 479, in _make_worker
    worker = cls(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 587, in __init__
    self._build_policy_map(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1550, in _build_policy_map
    self.policy_map.create_policy(name, orig_cls, obs_space, act_space,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy_map.py", line 143, in create_policy
    self[policy_id] = class_(observation_space, action_space,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 280, in __init__
    self._initialize_loss_from_dummy_batch(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 799, in _initialize_loss_from_dummy_batch
    self.compute_actions_from_input_dict(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 294, in compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 908, in _compute_action_helper
    self.action_distribution_fn(
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_policy.py", line 175, in action_distribution_fn
    _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_model.py", line 100, in get_q_values
    return self._get_q_value(model_out, actions, self.q_net, state_in,
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/agents/sac/rnnsac_torch_model.py", line 91, in _get_q_value
    out, state_out = net(model_out, state_in, seq_lens)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 243, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/recurrent_net.py", line 187, in forward
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/fcnet.py", line 124, in forward
    self._features = self._hidden_layers(self._last_flat_in)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/ray/rllib/models/torch/misc.py", line 160, in forward
    return self._model(x)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 96, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/wildsky/My_Venv/DRL/lib/python3.8/site-packages/torch/nn/functional.py", line 1847, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x9640 and 13671x10)

I use this code to train:

from ray.tune.registry import register_env
from ray.rllib.env.multi_agent_env import make_multi_agent
from env_rllib import Environment

from ray.rllib.models import ModelCatalog
import ray.rllib.agents.sac as sac
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_compute_single_action, \
    framework_iterator

from rnn_model import TorchRNNModel, RNNModel

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

MultiAgentAOI = make_multi_agent(Environment)

ModelCatalog.register_custom_model("lstm_model", TorchRNNModel)
ModelCatalog.register_custom_model("lstm_model_tf", RNNModel)
register_env("multi_agent_aoi" , lambda x : MultiAgentAOI({"num_agents": 5}))

config = sac.RNNSAC_DEFAULT_CONFIG.copy()
config["num_workers"] = 0  # Run locally.

config["model"] = {
    "max_seq_len": 100,
}
config["env"]= "multi_agent_aoi"
config["policy_model"] = {
    # "custom_model": "lstm_model",
    "fcnet_hiddens": [10],
    "use_lstm": True,
    "lstm_cell_size": 64,

    "lstm_use_prev_action": True,
    "lstm_use_prev_reward": True,
                          }
config["Q_model"] = {
    # "custom_model": "lstm_model",
    "fcnet_hiddens": [10],
    "use_lstm": True,

    "lstm_cell_size": 64,

    "lstm_use_prev_action": True,
    "lstm_use_prev_reward": True,

}

config["prioritized_replay"] = True

config["burn_in"] = 20
config["zero_init_states"] = True

config["lr"] = 5e-4


num_iterations = 1


for _ in framework_iterator(config, frameworks="torch"):
    trainer = sac.RNNSACTrainer(config=config, env="multi_agent_aoi")
    for i in range(num_iterations):
        results = trainer.train()
        print(results)
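
For reference, the custom Environment class isn't included in the issue. A minimal single-agent stand-in with the spaces described above might look roughly like this (the bounds, reward, and dynamics below are made up purely for illustration):

import numpy as np
import gym
from gym.spaces import Box

class Environment(gym.Env):
    """Hypothetical stand-in: continuous observations of shape (9640,)
    and continuous actions of shape (4031,), as described in the issue."""

    def __init__(self, config=None):
        self.observation_space = Box(-1.0, 1.0, shape=(9640,), dtype=np.float32)
        self.action_space = Box(-1.0, 1.0, shape=(4031,), dtype=np.float32)

    def reset(self):
        return self.observation_space.sample()

    def step(self, action):
        # Dummy dynamics: random next observation, zero reward, never done.
        return self.observation_space.sample(), 0.0, False, {}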

I don't quite understand this part of the build_q_model method:

def build_q_model(self, obs_space, action_space, num_outputs,
                  q_model_config, name):
    """Builds one of the (twin) Q-nets used by this SAC.

    Override this method in a sub-class of SACTFModel to implement your
    own Q-nets. Alternatively, simply set `custom_model` within the
    top level SAC `Q_model` config key to make this default implementation
    of `build_q_model` use your custom Q-nets.

    Returns:
        TorchModelV2: The TorchModelV2 Q-net sub-model.
    """
    self.concat_obs_and_actions = False
    if self.discrete:
        input_space = obs_space
    else:
        orig_space = getattr(obs_space, "original_space", obs_space)
        if isinstance(orig_space, Box) and len(orig_space.shape) == 1:
            input_space = Box(
                float("-inf"),
                float("inf"),
                shape=(orig_space.shape[0] + action_space.shape[0], ))
            self.concat_obs_and_actions = True
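
Plugging the spaces from this issue into that non-discrete branch shows where the 13671 in the error above comes from; a quick check, assuming plain 1-D Box spaces:

from gym.spaces import Box

obs_shape, act_shape = (9640,), (4031,)   # the issue's observation / action spaces
input_space = Box(
    float("-inf"), float("inf"),
    shape=(obs_shape[0] + act_shape[0],))
print(input_space.shape)  # (13671,) -> with fcnet_hiddens [10], the Q-net's first
                          #  Linear layer weight is (13671, 10), as in the error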

Thanks in advance for your guidance.

Versions / Dependencies

v2.0, v1.9

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
jamuus commented, May 15, 2022

Hey,

I’ve run into the same issue.

From the debugging I've done, it seems the calls to get_q_values in action_distribution_fn don't have actions passed in, so _get_q_value doesn't concatenate anything with the observation and the dimension mismatch occurs.

    _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens)
    if model.twin_q_net:
        _, twin_q_state_out = model.get_twin_q_values(
            model_out, states_in["twin_q"], seq_lens
        )

With #23814 and passing in input_dict['actions'], it progresses further, but I'm seeing other, seemingly unrelated issues. I also have no idea whether those are the actions expected at this point in the algorithm.
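
A rough sketch of the kind of change described above, forwarding the batch's actions into the Q-net calls in action_distribution_fn (the trailing actions argument reflects my reading of the model's get_q_values signature and may not match the eventual fix in #23814):

    # Hypothetical patch sketch in rnnsac_torch_policy.action_distribution_fn:
    actions = input_dict["actions"]
    _, q_state_out = model.get_q_values(model_out, states_in["q"], seq_lens, actions)
    if model.twin_q_net:
        _, twin_q_state_out = model.get_twin_q_values(
            model_out, states_in["twin_q"], seq_lens, actions
        )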

It appears the RNNSAC implementation hasn't been tested with continuous actions; it would be good if someone knowledgeable about how it's supposed to work could take a look. I've seen great performance from the torch implementation of the standard SAC so far.

0 reactions
avnishn commented, Feb 23, 2022

I can't quite understand why build_q_model is concatenating the action and obs.

In the RL literature, the Q-function takes both a state and an action as parameters. When we represent it as a neural network in code, we concatenate the observations and actions together to express that the Q-function is a function of these two things.

If your environment has a discrete action space, the Q-network doesn't take the action as an input at all; it takes only the observation and outputs a value per action, so no concatenation is needed.

So my best guess is that your issue has to do with the observations and actions not being concatenated when they're passed to the _get_q_value function of the RNNSAC torch model. This is probably because at some point self.concat_obs_and_actions is set to False, so only the observation, rather than the observation concatenated with the action, is passed to your Q-function.
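
To make the concatenation concrete with the shapes from this issue, here is a minimal torch sketch (not RLlib's actual Q-net class, just an illustration of the input sizing):

import torch
import torch.nn as nn

obs_dim, act_dim, hidden = 9640, 4031, 10   # dims from the issue and its Q_model config
q_net = nn.Sequential(
    nn.Linear(obs_dim + act_dim, hidden),   # expects obs and action concatenated
    nn.ReLU(),
    nn.Linear(hidden, 1))

obs = torch.randn(32, obs_dim)
act = torch.randn(32, act_dim)
q = q_net(torch.cat([obs, act], dim=-1))    # OK: input width 13671 matches the layer
# q_net(obs) alone raises exactly the error from the traceback:
# "mat1 and mat2 shapes cannot be multiplied (32x9640 and 13671x10)"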
