RLLib: Pytorch + PPO + RNN: KeyError: 'seq_lens' in the batch dictionary

What is the problem?

Installed Ray from the nightly wheel. I wrote a custom environment, model, and action distribution. When I try to train them with PPO, a KeyError is raised inside one of RLlib's internal objects: the training batch dictionary has no "seq_lens" key, which RLlib uses to mask recurrent models during backpropagation.

2020-02-18 16:11:55,261 ERROR trial_runner.py:513 -- Trial PPO_test_49f2c33a: Error processing event.
Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 459, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 377, in fetch_result
    result = ray.get(trial_future[0], DEFAULT_GET_TIMEOUT)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/worker.py", line 1522, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(KeyError): ray::PPO.train() (pid=13094, ip=10.0.2.217)
  File "python/ray/_raylet.pyx", line 447, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 425, in ray._raylet.execute_task.function_executor
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 477, in train
    raise e
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 463, in train
    result = Trainable.train(self)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/tune/trainable.py", line 254, in train
    result = self._train()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 122, in _train
    fetches = self.optimizer.step()
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/optimizers/sync_samples_optimizer.py", line 71, in step
    self.standardize_fields)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/sgd.py", line 111, in do_minibatch_sgd
    }, minibatch.count)))[policy_id]
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 619, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/policy/torch_policy.py", line 100, in learn_on_batch
    loss_out = self._loss(self, self.model, self.dist_class, train_batch)
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/agents/ppo/ppo_torch_policy.py", line 112, in ppo_surrogate_loss
    print(train_batch["seq_lens"])
  File "/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ray/rllib/utils/tracking_dict.py", line 22, in __getitem__
    value = dict.__getitem__(self, key)
KeyError: 'seq_lens'
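For context on the error: when a recurrent policy is trained, RLlib chops rollouts into fixed-length, zero-padded sequences, and `seq_lens` records how many steps of each sequence are real so that the loss can ignore the padding. The NumPy sketch below illustrates only that idea under simplifying assumptions (a single trajectory, no episode boundaries, per-step losses already computed); `chop_into_padded_sequences` and `sequence_mask` are hypothetical helpers, not RLlib APIs.

```python
import numpy as np


def chop_into_padded_sequences(values, max_seq_len):
    """Split one flat trajectory into zero-padded chunks of max_seq_len steps.

    Returns (padded, seq_lens): padded has shape [num_seqs, max_seq_len] and
    seq_lens[i] is the number of real (non-padding) steps in chunk i.
    """
    values = np.asarray(values, dtype=np.float32)
    num_seqs = -(-len(values) // max_seq_len)  # ceiling division
    padded = np.zeros((num_seqs, max_seq_len), dtype=np.float32)
    seq_lens = np.zeros(num_seqs, dtype=np.int32)
    for i in range(num_seqs):
        chunk = values[i * max_seq_len:(i + 1) * max_seq_len]
        padded[i, :len(chunk)] = chunk
        seq_lens[i] = len(chunk)
    return padded, seq_lens


def sequence_mask(seq_lens, max_seq_len):
    """Boolean mask of shape [num_seqs, max_seq_len]; True marks real steps."""
    return np.arange(max_seq_len)[None, :] < np.asarray(seq_lens)[:, None]


# Per-step losses for a 5-step trajectory, chopped into sequences of length 2.
losses, seq_lens = chop_into_padded_sequences([1.0, 2.0, 3.0, 4.0, 5.0], 2)
mask = sequence_mask(seq_lens, 2)
print(seq_lens.tolist())      # [2, 2, 1]
print(losses[mask].mean())    # 3.0 (padding excluded)
print(losses.mean())          # 2.5 (padding wrongly included)
```

Without `seq_lens` there is no way to build this mask, which is why a loss that expects it fails as soon as the key is missing from the batch.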

Reproduction (REQUIRED)

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.utils.annotations import override
from ray.tune.registry import register_env
import gym
from gym.spaces import Discrete, Box, Dict, MultiDiscrete
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_f32_array(number):
    return np.array(number, dtype="float32")


class TorchMultiCategorical(ActionDistribution):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(ActionDistribution)
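    def __init__(self, inputs, model):
        super().__init__(inputs, model)
        # Split the flat logits into one chunk per MultiDiscrete sub-action
        # (chunk sizes come from model.dist_input_lens)
        input_lens = model.dist_input_lens
        inputs_splitted = inputs.split(input_lens, dim=1)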
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_splitted
        ]

    @override(ActionDistribution)
    def sample(self):
        arr = [cat.sample() for cat in self.cats]
        ret = torch.stack(arr, dim=1)
        return ret

    @override(ActionDistribution)
    def logp(self, actions):
        # If a tensor is provided, unstack it into a list of per-component actions
        if isinstance(actions, torch.Tensor):
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self):
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(ActionDistribution)
    def entropy(self):
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other):
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1,
        )

    @override(ActionDistribution)
    def kl(self, other):
        return torch.sum(self.multi_kl(other), dim=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.sum(action_space.nvec)


class ReproEnv(gym.Env):
    def __init__(self, config):
        self.cur_pos = 0
        self.window_size = config["window_size"]
        self.need_reset = False
        self.action_space = MultiDiscrete([3, 2, 51, 10, 2])
        self.observation_space = Dict(
            {
                "lob": Box(low=-np.inf, high=np.inf, shape=(self.window_size, 40)),
                "unallocated_wealth": Box(low=0, high=1, shape=()),
                "taker_fees": Box(low=-1, high=1, shape=()),
                "maker_fees": Box(low=-1, high=1, shape=()),
                "order": Dict(
                    {
                        "side": Discrete(3),
                        "type": Discrete(2),
                        "size": Box(low=0, high=1, shape=()),
                        "price": Box(low=-np.inf, high=np.inf, shape=()),
                        "filled": Box(low=0, high=1, shape=()),
                    }
                ),
                "position": Dict(
                    {
                        "side": Discrete(3),
                        "size": Box(low=0, high=1, shape=()),
                        "entry_price": Box(low=-np.inf, high=np.inf, shape=()),
                        "unrealized_pnl": Box(low=-100, high=np.inf, shape=()),
                    }
                ),
            }
        )

    def reset(self):
        self.cur_pos = 0
        self.need_reset = False
        return self.step([0, 0, 0, 0, 0])[0]  # Noop

    def step(self, action):
        if self.need_reset:
            raise Exception("You need to reset this environment!")
        self.cur_pos += 1
        assert action in self.action_space, action
        if self.cur_pos >= 1000:
            done = True
        else:
            done = False
        info = {}
        if done:
            self.need_reset = True
        observation = {
            "lob": np.zeros((self.window_size, 40)),
            "taker_fees": _make_f32_array(0),
            "maker_fees": _make_f32_array(0),
            "unallocated_wealth": _make_f32_array(0),
            "order": {
                "side": 0,
                "type": 0,
                "size": _make_f32_array(0),
                "price": _make_f32_array(0),
                "filled": _make_f32_array(0),
            },
            "position": {
                "side": 0,
                "size": _make_f32_array(0),
                "entry_price": _make_f32_array(0),
                "unrealized_pnl": _make_f32_array(0),
            },
        }
        assert observation in self.observation_space, observation
        return observation, 0, done, info
        # return observation


class CNN(nn.Module):
    def __init__(self, dropout=0.2):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1, 2), stride=(1, 2))
        self.conv2 = nn.Conv2d(16, 16, kernel_size=(4, 1))
        self.conv3 = nn.Conv2d(16, 16, kernel_size=(4, 1))

        self.conv4 = nn.Conv2d(16, 32, kernel_size=(1, 2), stride=(1, 2))
        self.conv5 = nn.Conv2d(32, 32, kernel_size=(4, 1))
        self.conv6 = nn.Conv2d(32, 32, kernel_size=(4, 1))

        self.conv7 = nn.Conv2d(32, 64, kernel_size=(1, 10))
        self.conv8 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        self.conv9 = nn.Conv2d(64, 64, kernel_size=(4, 1))
        # Pad to preserve the length in the time domain
        self.pad = nn.ZeroPad2d((0, 0, 0, 3))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(self.pad(x)))
        x = F.leaky_relu(self.conv3(self.pad(x)))
        x = F.leaky_relu(self.conv4(x))
        x = F.leaky_relu(self.conv5(self.pad(x)))
        x = F.leaky_relu(self.conv6(self.pad(x)))
        x = F.leaky_relu(self.conv7(x))
        x = F.leaky_relu(self.conv8(self.pad(x)))
        x = F.leaky_relu(self.conv9(self.pad(x)))
        x = self.dropout(x)
        return x


class TestNet(TorchModelV2, nn.Module):
    def init_hidden(self, hidden_size):
        h0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        c0 = self._value_branch[0].weight.new(1, hidden_size).zero_()
        return (h0, c0)

    def __init__(self, obs_space, action_space, num_outputs, config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, config, name)
        nn.Module.__init__(self)
        model_config = config["custom_options"]
        print("Model config:")
        print(model_config)
        dropout = model_config["dropout"]
        window_size = model_config["window_size"]
        self.cnn = CNN(dropout=dropout)
        print(f"Dropout: {dropout}")
        print(f"Window size: {window_size}")

        # Value function
        self._value_branch = nn.Sequential(
            nn.Linear(6, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 1),
        )
        # Policy: Signal
        self.long_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)
        self.short_lstm = nn.LSTM(64 * window_size, 64, batch_first=True)

        self.long_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_025_fc = nn.Linear(32, 1)

        self.long_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_050_fc = nn.Linear(32, 1)

        self.long_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.long_075_fc = nn.Linear(32, 1)

        self.short_025_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_025_fc = nn.Linear(32, 1)

        self.short_050_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_050_fc = nn.Linear(32, 1)

        self.short_075_lstm = nn.LSTM(64, 32, batch_first=True)
        self.short_075_fc = nn.Linear(32, 1)
        # Policy: Brain
        self.dumb_fc = nn.Linear(6, 68)

        self.dist_input_lens = [3, 2, 51, 10, 2]
        self._cur_value = None

    @override(TorchModelV2)
    def get_initial_state(self):
        # make hidden states on same device as model
        long_lstm_h, long_lstm_c = self.init_hidden(64)
        short_lstm_h, short_lstm_c = self.init_hidden(64)

        long_025_lstm_h, long_025_lstm_c = self.init_hidden(32)
        long_050_lstm_h, long_050_lstm_c = self.init_hidden(32)
        long_075_lstm_h, long_075_lstm_c = self.init_hidden(32)

        short_025_lstm_h, short_025_lstm_c = self.init_hidden(32)
        short_050_lstm_h, short_050_lstm_c = self.init_hidden(32)
        short_075_lstm_h, short_075_lstm_c = self.init_hidden(32)

        initial_state = [
            long_lstm_h,
            long_lstm_c,
            short_lstm_h,
            short_lstm_c,
            long_025_lstm_h,
            long_025_lstm_c,
            long_050_lstm_h,
            long_050_lstm_c,
            long_075_lstm_h,
            long_075_lstm_c,
            short_025_lstm_h,
            short_025_lstm_c,
            short_050_lstm_h,
            short_050_lstm_c,
            short_075_lstm_h,
            short_075_lstm_c,
        ]
        return initial_state

    @override(TorchModelV2)
    def value_function(self):
        assert self._cur_value is not None, "must call forward() first"
        return self._cur_value

    def forward(self, input_dict, hidden_state, seq_lens):
        # if seq_lens is None:
        #     raise Exception("seq_lens is None")
        lob = input_dict["obs"]["lob"]
        batch_size, window_length, features = lob.size()
        # assert list(hidden_state[0].size()) == [1, 1, 64]
        # Unpack the hidden_state
        long_lstm_h = hidden_state[0]
        long_lstm_c = hidden_state[1]
        short_lstm_h = hidden_state[2]
        short_lstm_c = hidden_state[3]

        long_025_lstm_h = hidden_state[4]
        long_025_lstm_c = hidden_state[5]
        long_050_lstm_h = hidden_state[6]
        long_050_lstm_c = hidden_state[7]
        long_075_lstm_h = hidden_state[8]
        long_075_lstm_c = hidden_state[9]

        short_025_lstm_h = hidden_state[10]
        short_025_lstm_c = hidden_state[11]
        short_050_lstm_h = hidden_state[12]
        short_050_lstm_c = hidden_state[13]
        short_075_lstm_h = hidden_state[14]
        short_075_lstm_c = hidden_state[15]
        # Build the tuples
        long_lstm_hidden = (long_lstm_h.view(1, -1, 64), long_lstm_c.view(1, -1, 64))
        short_lstm_hidden = (short_lstm_h.view(1, -1, 64), short_lstm_c.view(1, -1, 64))
        long_025_lstm_hidden = (
            long_025_lstm_h.view(1, -1, 32),
            long_025_lstm_c.view(1, -1, 32),
        )
        long_050_lstm_hidden = (
            long_050_lstm_h.view(1, -1, 32),
            long_050_lstm_c.view(1, -1, 32),
        )
        long_075_lstm_hidden = (
            long_075_lstm_h.view(1, -1, 32),
            long_075_lstm_c.view(1, -1, 32),
        )
        short_025_lstm_hidden = (
            short_025_lstm_h.view(1, -1, 32),
            short_025_lstm_c.view(1, -1, 32),
        )
        short_050_lstm_hidden = (
            short_050_lstm_h.view(1, -1, 32),
            short_050_lstm_c.view(1, -1, 32),
        )
        short_075_lstm_hidden = (
            short_075_lstm_h.view(1, -1, 32),
            short_075_lstm_c.view(1, -1, 32),
        )

        c_in = lob.view(batch_size, 1, window_length, features)
        c_out = self.cnn(c_in)
        # Embeddings from the CNN, reshaped to be consumed by the LSTM
        embeddings = c_out.view(batch_size, 1, -1)

        long_out, long_lstm_hidden = self.long_lstm(embeddings, long_lstm_hidden)
        short_out, short_lstm_hidden = self.short_lstm(embeddings, short_lstm_hidden)
        # Now on to the tail LSTMs
        long_025_out, long_025_lstm_hidden = self.long_025_lstm(
            F.leaky_relu(long_out), long_025_lstm_hidden
        )
        long_050_out, long_050_lstm_hidden = self.long_050_lstm(
            F.leaky_relu(long_out), long_050_lstm_hidden
        )
        long_075_out, long_075_lstm_hidden = self.long_075_lstm(
            F.leaky_relu(long_out), long_075_lstm_hidden
        )

        short_025_out, short_025_lstm_hidden = self.short_025_lstm(
            F.leaky_relu(short_out), short_025_lstm_hidden
        )
        short_050_out, short_050_lstm_hidden = self.short_050_lstm(
            F.leaky_relu(short_out), short_050_lstm_hidden
        )
        short_075_out, short_075_lstm_hidden = self.short_075_lstm(
            F.leaky_relu(short_out), short_075_lstm_hidden
        )
        # Reshape the outputs of the tail LSTMs into (batch, hidden_size)
        long_025_out = long_025_out.view(batch_size, -1)
        long_050_out = long_050_out.view(batch_size, -1)
        long_075_out = long_075_out.view(batch_size, -1)

        short_025_out = short_025_out.view(batch_size, -1)
        short_050_out = short_050_out.view(batch_size, -1)
        short_075_out = short_075_out.view(batch_size, -1)
        # Fully connected at the end
        long_025_q = self.long_025_fc(F.leaky_relu(long_025_out))
        long_050_q = self.long_050_fc(F.leaky_relu(long_050_out))
        long_075_q = self.long_075_fc(F.leaky_relu(long_075_out))

        short_025_q = self.short_025_fc(F.leaky_relu(short_025_out))
        short_050_q = self.short_050_fc(F.leaky_relu(short_050_out))
        short_075_q = self.short_075_fc(F.leaky_relu(short_075_out))
        quantiles = [
            long_025_q,
            long_050_q,
            long_075_q,
            short_025_q,
            short_050_q,
            short_075_q,
        ]
        quantiles = torch.cat(quantiles, dim=1).view(batch_size, 6)
        new_hidden_state = [
            long_lstm_hidden[0],
            long_lstm_hidden[1],
            short_lstm_hidden[0],
            short_lstm_hidden[1],
            long_025_lstm_hidden[0],
            long_025_lstm_hidden[1],
            long_050_lstm_hidden[0],
            long_050_lstm_hidden[1],
            long_075_lstm_hidden[0],
            long_075_lstm_hidden[1],
            short_025_lstm_hidden[0],
            short_025_lstm_hidden[1],
            short_050_lstm_hidden[0],
            short_050_lstm_hidden[1],
            short_075_lstm_hidden[0],
            short_075_lstm_hidden[1],
        ]
        assert list(new_hidden_state[0].size()) == [
            1,
            list(new_hidden_state[0].size())[1],
            64,
        ], new_hidden_state[0].size()
        # Value function
        self._cur_value = self._value_branch(quantiles).squeeze(1)
        logits = self.dumb_fc(quantiles)
        return logits, new_hidden_state


ModelCatalog.register_custom_action_dist("torchmulticategorical", TorchMultiCategorical)
ModelCatalog.register_custom_model("test", TestNet)
register_env("test", lambda config: ReproEnv(config))

ray.init()
tune.run(
    ppo.PPOTrainer,
    config={
        "num_workers": 1,
        "env": "test",
        "log_level": "INFO",
        "use_pytorch": True,
        "num_gpus": 1,
        "vf_share_layers": True,
        "env_config": {"window_size": 100,},
        "model": {
            "custom_action_dist": "torchmulticategorical",
            "custom_model": "test",
            "custom_options": {
                "window_size": 100,
                "dropout": 0.2,
                "use_learned_hidden": True,
            },
        },
    },
)
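The `TorchMultiCategorical` distribution above treats the 68 logits produced by `dumb_fc` as five independent categoricals (3 + 2 + 51 + 10 + 2 = 68) and sums their log-probabilities. A NumPy-only sketch of that computation, handy for sanity-checking the model's output width against `required_model_output_shape`, might look like the following; `log_softmax` and `multi_categorical_logp` are illustrative names, not RLlib functions.

```python
import numpy as np

# Sub-action sizes of ReproEnv's MultiDiscrete([3, 2, 51, 10, 2]) action space;
# the model must emit sum(INPUT_LENS) == 68 logits, matching dumb_fc.
INPUT_LENS = [3, 2, 51, 10, 2]


def log_softmax(x):
    """Numerically stable log-softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def multi_categorical_logp(logits, actions, input_lens=INPUT_LENS):
    """Joint log-probability of MultiDiscrete actions.

    logits: float array [batch, sum(input_lens)]
    actions: int array [batch, len(input_lens)]
    """
    batch = np.arange(logits.shape[0])
    split_points = np.cumsum(input_lens)[:-1]
    total = np.zeros(logits.shape[0])
    for i, chunk in enumerate(np.split(logits, split_points, axis=1)):
        # Independent categoricals: their log-probabilities add up.
        total += log_softmax(chunk)[batch, actions[:, i]]
    return total


# With all-zero logits every sub-action is uniform, so the joint
# log-probability is -log(3 * 2 * 51 * 10 * 2) = -log(6120).
logits = np.zeros((4, sum(INPUT_LENS)))
actions = np.zeros((4, len(INPUT_LENS)), dtype=int)
print(np.allclose(multi_categorical_logp(logits, actions), -np.log(6120)))  # True
```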

  • I have verified my script runs in a clean environment and reproduces the issue.
  • I have verified the issue also occurs with the latest wheels.
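`TestNet` sizes its first LSTMs as `nn.LSTM(64 * window_size, ...)`, which only works if the CNN collapses the 40 order-book features to a single column while the bottom zero padding keeps the time axis at `window_size`. A small pure-Python shape trace makes that arithmetic explicit; it assumes the 40-feature `lob` rows from `ReproEnv`, and `conv_out` / `cnn_embedding_size` are illustrative helpers, not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, total_pad=0):
    """Output length along one axis of a convolution (no dilation)."""
    return (size + total_pad - kernel) // stride + 1


def cnn_embedding_size(window_size, features=40):
    """Trace (time, feature) sizes through conv1..conv9 of the repro's CNN."""
    n_time, n_feat = window_size, features
    # (kernel_time, kernel_feat, stride_feat, padded) per layer. The (4, 1)
    # convs are preceded by ZeroPad2d((0, 0, 0, 3)), i.e. 3 extra time rows.
    layers = [
        (1, 2, 2, False), (4, 1, 1, True), (4, 1, 1, True),    # conv1-3
        (1, 2, 2, False), (4, 1, 1, True), (4, 1, 1, True),    # conv4-6
        (1, 10, 1, False), (4, 1, 1, True), (4, 1, 1, True),   # conv7-9
    ]
    for k_time, k_feat, s_feat, padded in layers:
        n_time = conv_out(n_time, k_time, 1, 3 if padded else 0)
        n_feat = conv_out(n_feat, k_feat, s_feat)
    out_channels = 64  # conv7..conv9 produce 64 channels
    return out_channels * n_time * n_feat


print(cnn_embedding_size(100))  # 6400 == 64 * 100
```

The feature axis shrinks 40 → 20 → 10 → 1 through the strided and width-10 convolutions, while each (4, 1) convolution removes exactly the 3 rows that the padding adds.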

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 2
  • Comments: 16 (15 by maintainers)

Top GitHub Comments

1 reaction
janblumenkamp commented, Mar 30, 2020

Awesome, thank you very much Sven! I will try it!

1 reaction
sven1977 commented, Mar 28, 2020

Almost there. Just some flaws now in the PPO loss concerning valid_mask. Apologies for the docs mentioning that we do generically support pytorch + LSTMs: We don’t (yet)! There will be a PR (probably tomorrow), which will fix that for at least the standard PG-algos: PPO, PG, A2C/A3C, iff one uses a custom torch Model. Making the “use_lstm” auto-wrapping functionality work will be a follow-up PR.
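Sven's remark about `valid_mask` concerns the padding used for recurrent batches: the PPO loss has to average only over real (non-padded) timesteps. A simplified NumPy sketch of a masked PPO clipped-surrogate term follows. It covers the policy term only (no value-function, KL, or entropy terms), `masked_ppo_surrogate_loss` is a hypothetical name, and `clip_param=0.3` is just an example value, so treat it as an illustration rather than RLlib's implementation.

```python
import numpy as np


def masked_ppo_surrogate_loss(logp_new, logp_old, advantages, valid_mask,
                              clip_param=0.3):
    """PPO clipped-surrogate policy loss averaged over valid steps only.

    All inputs have shape [num_seqs, max_seq_len]; valid_mask is True for
    real steps and False for padding. clip_param=0.3 is an example value.
    """
    ratio = np.exp(logp_new - logp_old)
    surrogate = np.minimum(
        advantages * ratio,
        advantages * np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param),
    )
    # Average over real timesteps only; padded ones would skew the mean.
    return -surrogate[valid_mask].mean()


advantages = np.array([[1.0, 2.0], [3.0, 100.0]])  # 100.0 sits in padding
logp = np.zeros_like(advantages)                    # ratio == 1 everywhere
mask = np.array([[True, True], [True, False]])
print(masked_ppo_surrogate_loss(logp, logp, advantages, mask))  # -2.0
```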
