RLlib: PyTorch + PPO + RNN: KeyError: 'seq_lens' in the batch dictionary
What is the problem?
I installed Ray with the nightly wheel.
I wrote a custom env, model, and action distribution.
When I attempt to train it with PPO, I get a KeyError from one of RLlib's internal objects: the train batch dict is missing "seq_lens", the key used for masking the recurrent model when backpropagating.
2020-02-18 16:11:55,261 ERROR trial_runner.py:513 -- Trial PPO_test_49f2c33a: Error processing event.
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 459, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 377, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/worker.py", line 1522, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=13094, ip=10.0.2.217)
  File "python/ray/_raylet.pyx", line 447, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 425, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 477, in train
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 463, in train
    result = Trainable.train(self)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trainable.py", line 254, in train
    result = self._train()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 122, in _train
    fetches = self.optimizer.step()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/sgd.py", line 111, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 619, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 100, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 112, in ppo_surrogate_loss
    print(train_batch["seq_lens"])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
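For context on what the missing key is for: RLlib pads recurrent rollout chunks to a fixed max length and stores the true lengths under "seq_lens", so the loss can ignore padded timesteps. A minimal sketch of that masking idea (illustrative only, not RLlib's internal code; the helper name is made up):

import torch

def masked_mean_loss(per_timestep_loss, seq_lens, max_seq_len):
    # per_timestep_loss: [num_sequences * max_seq_len]
    # seq_lens: [num_sequences] number of valid (non-padded) steps per sequence
    time_range = torch.arange(max_seq_len).unsqueeze(0)  # [1, T]
    valid_mask = time_range < seq_lens.unsqueeze(1)      # [B, T], True on real steps
    valid_mask = valid_mask.reshape(-1).float()          # [B * T]
    return torch.sum(per_timestep_loss * valid_mask) / torch.sum(valid_mask)

# Two sequences padded to length 4, with 3 and 2 valid steps respectively.
print(masked_mean_loss(torch.ones(2 * 4), torch.tensor([3, 2]), 4))  # tensor(1.)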
Reproduction (REQUIRED)
import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.policy.policy import TupleActions
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.tune.registry import register_env
import gym
from gym.spaces import Discrete, Box, Dict, MultiDiscrete
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn import Parameter
from torch import Tensor
def _make_f32_array(number):
    return np.array(number, dtype="float32")
class TorchMultiCategorical(ActionDistribution):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(ActionDistribution)
    def __init__(self, inputs, model):
        input_lens = model.dist_input_lens
        inputs_splitted = inputs.split(input_lens, dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_splitted
        ]

    @override(ActionDistribution)
    def sample(self):
        arr = [cat.sample() for cat in self.cats]
        ret = torch.stack(arr, dim=1)
        return ret

    @override(ActionDistribution)
    def logp(self, actions):
        # If a tensor is provided, unstack it into a list
        if isinstance(actions, torch.Tensor):
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self):
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(ActionDistribution)
    def entropy(self):
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other):
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1,
        )

    @override(ActionDistribution)
    def kl(self, other):
        return torch.sum(self.multi_kl(other), dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.sum(action_space.nvec)
class ReproEnv(gym.Env):
    def __init__(self, config):
        self.cur_pos = 0
        self.window_size = config["window_size"]
        self.need_reset = False
        self.action_space = MultiDiscrete([3, 2, 51, 10, 2])
        self.observation_space = Dict(
            {
                "lob": Box(low=-np.inf, high=np.inf, shape=(self.window_size, 40)),
                "unallocated_wealth": Box(low=0, high=1, shape=()),
                "taker_fees": Box(low=-1, high=1, shape=()),
                "maker_fees": Box(low=-1, high=1, shape=()),
                "order": Dict(
                    {
                        "side": Discrete(3),
                        "type": Discrete(2),
                        "size": Box(low=0, high=1, shape=()),
                        "price": Box(low=-np.inf, high=np.inf, shape=()),
                        "filled": Box(low=0, high=1, shape=()),
                    }
                ),
                "position": Dict(
                    {
                        "side": Discrete(3),
                        "size": Box(low=0, high=1, shape=()),
                        "entry_price": Box(low=-np.inf, high=np.inf, shape=()),
                        "unrealized_pnl": Box(low=-100, high=np.inf, shape=()),
                    }
                ),
            }
        )

    def reset(self):
        self.cur_pos = 0
        self.need_reset = False
        return self.step([0, 0, 0, 0, 0])[0]  # Noop

    def step(self, action):
        if self.need_reset:
            raise Exception("You need to reset this environment!")
        self.cur_pos += 1
        assert action in self.action_space, action
        if self.cur_pos >= 1000:
            done = True
        else:
            done = False
        info = {}
        if done:
            self.need_reset = True
        observation = {
            "lob": np.zeros((self.window_size, 40)),
            "taker_fees": _make_f32_array(0),
            "maker_fees": _make_f32_array(0),
            "unallocated_wealth": _make_f32_array(0),
            "order": {
                "side": 0,
                "type": 0,
                "size": _make_f32_array(0),
                "price": _make_f32_array(0),
                "filled": _make_f32_array(0),
            },
            "position": {
                "side": 0,
                "size": _make_f32_array(0),
                "entry_price": _make_f32_array(0),
                "unrealized_pnl": _make_f32_array(0),
            },
        }
        assert observation in self.observation_space, observation
        return observation, 0, done, info
class CNN(nn.Module):
    def __init__(self, dropout=0.2):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(16, 16, kernel_size=(4, 1))
        self.conv3 = nn.Conv2d(16, 16, kernel_size=(4, 1))
        self.conv4 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
        self.conv5 = nn.Conv2d(32, 32, kernel_size=(4, 1))
        self.conv6 = nn.Conv2d(32, 32, kernel_size=(4, 1))
        self.conv7 = nn.Conv2d(32, 64, kernel_size=(1, 10))
        self.conv8 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        self.conv9 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        # Pad to preserve the length in the time domain
        self.pad = nn.ZeroPad2d((0, 0, 0, 3))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(self.pad(x)))
        x = F.leaky_relu(self.conv3(self.pad(x)))
        x = F.leaky_relu(self.conv4(x))
        x = F.leaky_relu(self.conv5(self.pad(x)))
        x = F.leaky_relu(self.conv6(self.pad(x)))
        x = F.leaky_relu(self.conv7(x))
        x = F.leaky_relu(self.conv8(self.pad(x)))
        x = F.leaky_relu(self.conv9(self.pad(x)))
        x = self.dropout(x)
        return x
class TestNet(TorchModelV2, nn.Module):
    def init_hidden(self, hidden_size):
        h0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        c0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        return (h0, c0)

    def __init__(self, obs_space, action_space, num_outputs, config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, config, name)
        nn.Module.__init__(self)
        model_config = config["custom_options"]
        print("Model config:")
        print(model_config)
        dropout = model_config["dropout"]
        window_size = model_config["window_size"]
        self.cnn = CNN(dropout=dropout)
        print(f"Dropout: {dropout}")
        print(f"Window size: {window_size}")
        # Value function
        self._value_branch = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )
        # Policy: Signal
        self.long_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
        self.short_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
        self.long_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_025_fc = nn.Linear(32, 1)
        self.long_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_050_fc = nn.Linear(32, 1)
        self.long_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_075_fc = nn.Linear(32, 1)
        self.short_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_025_fc = nn.Linear(32, 1)
        self.short_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_050_fc = nn.Linear(32, 1)
        self.short_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_075_fc = nn.Linear(32, 1)
        # Policy: Brain
        self.dumb_fc = nn.Linear(6, 68)
        self.dist_input_lens = [3, 2, 51, 10, 2]
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        long_lstm_h, long_lstm_c = self.init_hidden(64)
        short_lstm_h, short_lstm_c = self.init_hidden(64)
        long_025_lstm_h, long_025_lstm_c = self.init_hidden(32)
        long_050_lstm_h, long_050_lstm_c = self.init_hidden(32)
        long_075_lstm_h, long_075_lstm_c = self.init_hidden(32)
        short_025_lstm_h, short_025_lstm_c = self.init_hidden(32)
        short_050_lstm_h, short_050_lstm_c = self.init_hidden(32)
        short_075_lstm_h, short_075_lstm_c = self.init_hidden(32)
        initial_state = [
            long_lstm_h,
            long_lstm_c,
            short_lstm_h,
            short_lstm_c,
            long_025_lstm_h,
            long_025_lstm_c,
            long_050_lstm_h,
            long_050_lstm_c,
            long_075_lstm_h,
            long_075_lstm_c,
            short_025_lstm_h,
            short_025_lstm_c,
            short_050_lstm_h,
            short_050_lstm_c,
            short_075_lstm_h,
            short_075_lstm_c,
        ]
        return initial_state
    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    def forward(self, input_dict, hidden_state, seq_lens):
        # if seq_lens is None:
        #     raise Exception("seq_lens is None")
        lob = input_dict["obs"]["lob"]
        batch_size, window_length, features = lob.size()
        # assert list(hidden_state[0].size()) == [1, 1, 64]
        # Unpack the hidden_state
        long_lstm_h = hidden_state[0]
        long_lstm_c = hidden_state[1]
        short_lstm_h = hidden_state[2]
        short_lstm_c = hidden_state[3]
        long_025_lstm_h = hidden_state[4]
        long_025_lstm_c = hidden_state[5]
        long_050_lstm_h = hidden_state[6]
        long_050_lstm_c = hidden_state[7]
        long_075_lstm_h = hidden_state[8]
        long_075_lstm_c = hidden_state[9]
        short_025_lstm_h = hidden_state[10]
        short_025_lstm_c = hidden_state[11]
        short_050_lstm_h = hidden_state[12]
        short_050_lstm_c = hidden_state[13]
        short_075_lstm_h = hidden_state[14]
        short_075_lstm_c = hidden_state[15]
        # Build the tuples
        long_lstm_hidden = (long_lstm_h.view(1, -1, 64), long_lstm_c.view(1, -1, 64))
        short_lstm_hidden = (short_lstm_h.view(1, -1, 64), short_lstm_c.view(1, -1, 64))
        long_025_lstm_hidden = (
            long_025_lstm_h.view(1, -1, 32),
            long_025_lstm_c.view(1, -1, 32),
        )
        long_050_lstm_hidden = (
            long_050_lstm_h.view(1, -1, 32),
            long_050_lstm_c.view(1, -1, 32),
        )
        long_075_lstm_hidden = (
            long_075_lstm_h.view(1, -1, 32),
            long_075_lstm_c.view(1, -1, 32),
        )
        short_025_lstm_hidden = (
            short_025_lstm_h.view(1, -1, 32),
            short_025_lstm_c.view(1, -1, 32),
        )
        short_050_lstm_hidden = (
            short_050_lstm_h.view(1, -1, 32),
            short_050_lstm_c.view(1, -1, 32),
        )
        short_075_lstm_hidden = (
            short_075_lstm_h.view(1, -1, 32),
            short_075_lstm_c.view(1, -1, 32),
        )
        c_in = lob.view(batch_size, 1, window_length, features)
        c_out = self.cnn(c_in)
        # Embeddings from the CNN, reshaped to be consumed by the LSTMs
        embeddings = c_out.view(batch_size, 1, -1)
        long_out, long_lstm_hidden = self.long_lstm(embeddings, long_lstm_hidden)
        short_out, short_lstm_hidden = self.short_lstm(embeddings, short_lstm_hidden)
        # Now on to the tail LSTMs
        long_025_out, long_025_lstm_hidden = self.long_025_lstm(
            F.leaky_relu(long_out), long_025_lstm_hidden
        )
        long_050_out, long_050_lstm_hidden = self.long_050_lstm(
            F.leaky_relu(long_out), long_050_lstm_hidden
        )
        long_075_out, long_075_lstm_hidden = self.long_075_lstm(
            F.leaky_relu(long_out), long_075_lstm_hidden
        )
        short_025_out, short_025_lstm_hidden = self.short_025_lstm(
            F.leaky_relu(short_out), short_025_lstm_hidden
        )
        short_050_out, short_050_lstm_hidden = self.short_050_lstm(
            F.leaky_relu(short_out), short_050_lstm_hidden
        )
        short_075_out, short_075_lstm_hidden = self.short_075_lstm(
            F.leaky_relu(short_out), short_075_lstm_hidden
        )
        # Reshape the outputs of the tail LSTMs into (batch, hidden_size)
        long_025_out = long_025_out.view(batch_size, -1)
        long_050_out = long_050_out.view(batch_size, -1)
        long_075_out = long_075_out.view(batch_size, -1)
        short_025_out = short_025_out.view(batch_size, -1)
        short_050_out = short_050_out.view(batch_size, -1)
        short_075_out = short_075_out.view(batch_size, -1)
        # Fully connected at the end
        long_025_q = self.long_025_fc(F.leaky_relu(long_025_out))
        long_050_q = self.long_050_fc(F.leaky_relu(long_050_out))
        long_075_q = self.long_075_fc(F.leaky_relu(long_075_out))
        short_025_q = self.short_025_fc(F.leaky_relu(short_025_out))
        short_050_q = self.short_050_fc(F.leaky_relu(short_050_out))
        short_075_q = self.short_075_fc(F.leaky_relu(short_075_out))
        quantiles = [
            long_025_q,
            long_050_q,
            long_075_q,
            short_025_q,
            short_050_q,
            short_075_q,
        ]
        quantiles = torch.cat(quantiles, dim=1).view(batch_size, 6)
        new_hidden_state = [
            long_lstm_hidden[0],
            long_lstm_hidden[1],
            short_lstm_hidden[0],
            short_lstm_hidden[1],
            long_025_lstm_hidden[0],
            long_025_lstm_hidden[1],
            long_050_lstm_hidden[0],
            long_050_lstm_hidden[1],
            long_075_lstm_hidden[0],
            long_075_lstm_hidden[1],
            short_025_lstm_hidden[0],
            short_025_lstm_hidden[1],
            short_050_lstm_hidden[0],
            short_050_lstm_hidden[1],
            short_075_lstm_hidden[0],
            short_075_lstm_hidden[1],
        ]
        assert list(new_hidden_state[0].size()) == [
            1,
            list(new_hidden_state[0].size())[1],
            64,
        ], new_hidden_state[0].size()
        # Value function
        self._cur_value = self._value_branch(quantiles).squeeze(1)
        logits = self.dumb_fc(quantiles)
        return logits, new_hidden_state
ModelCatalog.register_custom_action_dist("torchmulticategorical", TorchMultiCategorical)
ModelCatalog.register_custom_model("test", TestNet)
register_env("test", lambda config: ReproEnv(config))
ray.init()
tune.run(
    ppo.PPOTrainer,
    config={
        "num_workers": 1,
        "env": "test",
        "log_level": "INFO",
        "use_pytorch": True,
        "num_gpus": 1,
        "vf_share_layers": True,
        "env_config": {"window_size": 100},
        "model": {
            "custom_action_dist": "torchmulticategorical",
            "custom_model": "test",
            "custom_options": {
                "window_size": 100,
                "dropout": 0.2,
                "use_learned_hidden": True,
            },
        },
    },
)
- I have verified my script runs in a clean environment and reproduces the issue.
- I have verified the issue also occurs with the latest wheels.
Top GitHub Comments
Awesome, thank you very much Sven! I will try it!
Almost there. Just some flaws now in the PPO loss concerning valid_mask. Apologies for the docs mentioning that we generically support PyTorch + LSTMs: we don't (yet)! There will be a PR (probably tomorrow) which will fix that for at least the standard PG algos (PPO, PG, A2C/A3C), if one uses a custom torch Model. Making the "use_lstm" auto-wrapping functionality work will be a follow-up PR.
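For anyone blocked on this before that PR lands, the idea behind the valid_mask fix is the same masking shown earlier, applied inside the torch PPO loss: build a mask from train_batch["seq_lens"] and average the surrogate terms only over valid timesteps. A rough, hypothetical sketch (not the actual RLlib patch; the function and argument names are made up, and inputs are assumed to be flattened [batch * max_seq_len] tensors):

import torch

def masked_ppo_surrogate(logp, old_logp, advantages, seq_lens, max_seq_len, clip_param=0.3):
    # Equivalent of tf.sequence_mask: True for real steps, False for padding.
    t = torch.arange(max_seq_len, device=logp.device).unsqueeze(0)
    valid_mask = (t < seq_lens.unsqueeze(1)).reshape(-1).float()

    def masked_mean(x):
        return torch.sum(x * valid_mask) / torch.sum(valid_mask)

    ratio = torch.exp(logp - old_logp)
    clipped = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    surrogate = torch.min(advantages * ratio, advantages * clipped)
    return -masked_mean(surrogate)  # loss to be minimized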