[rllib] RNN model receives empty state/seq_lens when upgraded from 1.0.0 to 1.0.1
What is the problem?
A classic PyTorch RNN model (one that does not use the RecurrentNetwork wrapper) no longer receives proper state and seq_lens values after upgrading from Ray 1.0.0 to 1.0.1.
Ray version and other system information (Python version, TensorFlow version, OS): Ubuntu 20.04, Ray 1.0.0/1.0.1, PyTorch 1.6/1.7
Reproduction (REQUIRED)
The code below works in 1.0.0, but fails in 1.0.1 with the following error (a possible guard is sketched after the reproduction script):
max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
AttributeError: 'NoneType' object has no attribute 'shape'
import ray
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray import tune
from ray.rllib.models import ModelCatalog, ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils import override


class CartPoleModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 128)
        self.lstm = nn.LSTM(128, 128, batch_first=True)
        self.action = nn.Linear(128, 2)
        self.value = nn.Linear(128, 1)
        self._value_out = None

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs_flat"]
        x = F.relu(self.fc1(obs))
        x = F.relu(self.fc2(x))
        flat_inputs = input_dict["obs_flat"].float()
        # In 1.0.1, seq_lens arrives as None here, so this line raises
        # AttributeError: 'NoneType' object has no attribute 'shape'.
        max_seq_len = flat_inputs.shape[0] // seq_lens.shape[0]
        x_seq = add_time_dimension(x, max_seq_len=max_seq_len, framework="torch")
        x, new_state = self.lstm(
            x_seq, (state[0].unsqueeze(0), state[1].unsqueeze(0)))
        action = self.action(x)
        value = self.value(x)
        self._value_out = value.reshape(-1)
        return torch.reshape(action, [-1, self.num_outputs]), [
            new_state[0].squeeze(0), new_state[1].squeeze(0)]

    @override(TorchModelV2)
    def value_function(self):
        return self._value_out

    @override(ModelV2)
    def get_initial_state(self):
        # Zero-initialized hidden and cell states for the LSTM.
        return [
            self.fc1.weight.new(1, 128).zero_().squeeze(0),
            self.fc1.weight.new(1, 128).zero_().squeeze(0),
        ]


ModelCatalog.register_custom_model("CartPoleModel", CartPoleModel)

if __name__ == "__main__":
    ray.init()
    tune.run(
        "PPO",
        checkpoint_freq=1,
        checkpoint_at_end=True,
        stop={"episode_reward_mean": 199.0},
        config={
            "env": "CartPole-v0",
            "framework": "torch",
            "model": {
                "custom_model": "CartPoleModel",
                "max_seq_len": 5,
            },
        },
    )
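As a stopgap, the crash can be avoided by guarding against a missing seq_lens at the top of forward(). This is a minimal sketch, not an official RLlib fix; the fallback of treating every batch row as its own one-step sequence is an assumption (it matches the single-timestep batches seen at inference time):

    # Hypothetical guard at the top of CartPoleModel.forward().
    if seq_lens is None:
        # Assume each row in the batch is a sequence of length 1.
        seq_lens = torch.ones(
            input_dict["obs_flat"].shape[0],
            dtype=torch.int32,
            device=input_dict["obs_flat"].device,
        )

With this guard, max_seq_len evaluates to 1 at inference time and add_time_dimension produces a [B, 1, 128] tensor, which the LSTM can still consume.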
- I have verified my script runs in a clean environment and reproduces the issue.
- I have verified the issue also occurs with the latest wheels.
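A possibly more robust alternative is to let RLlib handle the sequencing itself by subclassing the RecurrentNetwork wrapper mentioned in the title, whose base forward() computes max_seq_len internally, instead of raw TorchModelV2. A sketch under that assumption, reusing the architecture from the reproduction script (I have not verified this against 1.0.1):

    from ray.rllib.models.torch.recurrent_net import RecurrentNetwork

    class CartPoleRNNModel(RecurrentNetwork, nn.Module):
        def __init__(self, obs_space, action_space, num_outputs, model_config, name):
            RecurrentNetwork.__init__(self, obs_space, action_space,
                                      num_outputs, model_config, name)
            nn.Module.__init__(self)
            self.fc1 = nn.Linear(4, 128)
            self.fc2 = nn.Linear(128, 128)
            self.lstm = nn.LSTM(128, 128, batch_first=True)
            self.action = nn.Linear(128, 2)
            self.value = nn.Linear(128, 1)
            self._value_out = None

        @override(RecurrentNetwork)
        def forward_rnn(self, inputs, state, seq_lens):
            # `inputs` already carries the time dimension ([B, T, obs_dim]),
            # so no manual add_time_dimension() call is needed here.
            x = F.relu(self.fc1(inputs))
            x = F.relu(self.fc2(x))
            x, new_state = self.lstm(
                x, (state[0].unsqueeze(0), state[1].unsqueeze(0)))
            self._value_out = self.value(x).reshape(-1)
            # The base class reshapes the output to [-1, num_outputs].
            return self.action(x), [
                new_state[0].squeeze(0), new_state[1].squeeze(0)]

        @override(TorchModelV2)
        def value_function(self):
            return self._value_out

        @override(ModelV2)
        def get_initial_state(self):
            return [
                self.fc1.weight.new(1, 128).zero_().squeeze(0),
                self.fc1.weight.new(1, 128).zero_().squeeze(0),
            ]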

Hi, I’m experiencing the exact same problem as @npitsillos. In my case I’m also using a custom environment: it trains properly using PPO with a TorchPolicy and LSTMs, but when testing it, the moment I call agent.compute_action(obs) it throws the same error that @npitsillos reported. Could someone please help with this? I’ve also read above something about the trajectory_view_api and some hard fixes, but I don’t know where this trajectory API is located or which lines of code should be changed to make it work. If any of the users who have solved this bug could share which file and which line they changed, it would be helpful. @iamhatesz @jsuarez5341
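For the inference-time crash specifically, one thing that may help is to pass the recurrent state to compute_action explicitly, since the error only seems to occur when the model is called without any state. A minimal sketch, assuming `agent` is a restored PPO trainer and `env` is your environment (both names are illustrative):

    # Carry the LSTM state manually at inference time.
    state = agent.get_policy().get_initial_state()
    obs = env.reset()
    done = False
    while not done:
        # When a state is passed, compute_action returns a tuple of
        # (action, state_out, extra_fetches).
        action, state, _ = agent.compute_action(obs, state=state)
        obs, reward, done, info = env.step(action)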
same bug…
Can’t make custom LSTM model work…