
[rllib] 'use_lstm' in PPO value function

See original GitHub issue

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
  • Ray installed from (source or binary): source
  • Ray version: 0.5.0
  • Python version: 3.6.1
  • Exact command to reproduce:

Describe the problem

I’m looking into the source code of PPO + LSTM, and I found that I cannot use both LSTM and GAE for the value function.

https://github.com/ray-project/ray/blob/079c4e482acd1300d65505543d1511b141e9bd30/python/ray/rllib/agents/ppo/ppo_policy_graph.py#L156-L166

From what I’ve seen, in the paper “Learning Dexterous In-Hand Manipulation” OpenAI successfully used both LSTM and GAE.

Question

  1. Why can’t I use both LSTM and GAE in the framework?

  2. If I remove the vf_config["use_lstm"] = False statement, can I use LSTM for the value function without any problems?

Thank you for your help.
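For context, the configuration being discussed looks roughly like the sketch below, assuming the Ray 0.5-era Python API (PPOAgent from ray.rllib.agents.ppo); the environment name is a placeholder:

import ray
from ray.rllib.agents.ppo import PPOAgent

config = {
    "use_gae": True,            # GAE for advantage estimation
    "model": {
        "use_lstm": True,       # wrap the policy model in an LSTM
        "lstm_cell_size": 256,  # the default cell size in lstm.py
    },
}

ray.init()
agent = PPOAgent(env="CartPole-v0", config=config)  # hypothetical env choice
print(agent.train())

With the code linked above, enabling use_gae builds a separate value-function model with use_lstm forced to False, so only the policy network can be recurrent; that is exactly the restriction the questions are about.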

Source code / logs

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 12 (6 by maintainers)

Top GitHub Comments

1 reaction
whikwon commented, Aug 21, 2018

@ericl It works. Thank you so much for the help.

1 reaction
ericl commented, Aug 21, 2018

@whikwon the following works for me (it’s a bit less modular than having two LSTM()s, but I think there are some complications with that approach).

Note that the new LSTM state is now a list of four elements:

diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
index 894d868..30cec63 100644
--- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py
+++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py
@@ -155,19 +155,7 @@ class PPOPolicyGraph(TFPolicyGraph):
         self.logits = self.model.outputs
         curr_action_dist = dist_cls(self.logits)
         self.sampler = curr_action_dist.sample()
-        if self.config["use_gae"]:
-            vf_config = self.config["model"].copy()
-            # Do not split the last layer of the value function into
-            # mean parameters and standard deviation parameters and
-            # do not make the standard deviations free variables.
-            vf_config["free_log_std"] = False
-            vf_config["use_lstm"] = False
-            with tf.variable_scope("value_function"):
-                self.value_function = ModelCatalog.get_model(
-                    obs_ph, 1, vf_config).outputs
-                self.value_function = tf.reshape(self.value_function, [-1])
-        else:
-            self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
+        self.value_function = self.model.value_function
 
         self.loss_obj = PPOLoss(
             action_space,

And the modified LSTM model (the LSTM class in python/ray/rllib/models/lstm.py):

# Imports added for context; add_time_dimension is defined in the same lstm.py module.
import distutils.version

import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn

from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.model import Model


class LSTM(Model):
    """Adds a LSTM cell on top of some other model output.

    Uses a linear layer at the end for output.

    Important: we assume inputs is a padded batch of sequences denoted by
        self.seq_lens. See add_time_dimension() for more information.
    """

    def _build_layers(self, inputs, num_outputs, options):
        cell_size = options.get("lstm_cell_size", 256)
        use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
                         distutils.version.LooseVersion("1.0.0"))
        last_layer = add_time_dimension(inputs, self.seq_lens)

        # Setup the LSTM cell
        if use_tf100_api:
            lstm1 = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
            lstm2 = rnn.BasicLSTMCell(cell_size, state_is_tuple=True)
        else:
            lstm1 = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
            lstm2 = rnn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
        self.state_init = [
            np.zeros(lstm1.state_size.c, np.float32),
            np.zeros(lstm1.state_size.h, np.float32),
            np.zeros(lstm2.state_size.c, np.float32),
            np.zeros(lstm2.state_size.h, np.float32)
        ]

        # Setup LSTM inputs
        if self.state_in:
            c1_in, h1_in, c2_in, h2_in = self.state_in
        else:
            c1_in = tf.placeholder(
                tf.float32, [None, lstm1.state_size.c], name="c1")
            h1_in = tf.placeholder(
                tf.float32, [None, lstm1.state_size.h], name="h1")
            c2_in = tf.placeholder(
                tf.float32, [None, lstm2.state_size.c], name="c2")
            h2_in = tf.placeholder(
                tf.float32, [None, lstm2.state_size.h], name="h2")
            self.state_in = [c1_in, h1_in, c2_in, h2_in]

        # Setup LSTM outputs
        if use_tf100_api:
            state1_in = rnn.LSTMStateTuple(c1_in, h1_in)
            state2_in = rnn.LSTMStateTuple(c2_in, h2_in)
        else:
            state1_in = rnn.rnn_cell.LSTMStateTuple(c1_in, h1_in)
            state2_in = rnn.rnn_cell.LSTMStateTuple(c2_in, h2_in)
        lstm1_out, lstm1_state = tf.nn.dynamic_rnn(
            lstm1,
            last_layer,
            initial_state=state1_in,  # thread the policy cell state back in
            sequence_length=self.seq_lens,
            time_major=False,
            dtype=tf.float32)

        with tf.variable_scope("value_function"):
            lstm2_out, lstm2_state = tf.nn.dynamic_rnn(
                lstm2,
                last_layer,
                initial_state=state2_in,  # separate recurrent state for the value branch
                sequence_length=self.seq_lens,
                time_major=False,
                dtype=tf.float32)

        self.value_function = tf.reshape(
            linear(
                tf.reshape(lstm2_out, [-1, cell_size]),
                1, "vf", normc_initializer(0.01)),
            [-1])

        self.state_out = list(lstm1_state) + list(lstm2_state)

        # Compute outputs
        last_layer = tf.reshape(lstm1_out, [-1, cell_size])
        logits = linear(last_layer, num_outputs, "action",
                        normc_initializer(0.01))

        return logits, last_layer
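
As a standalone illustration of the same pattern (outside RLlib, assuming TensorFlow 1.x; all names and sizes here are illustrative), two independent LSTM cells can read the same padded feature sequence, with one head producing the policy logits and the other the value estimate:

import numpy as np
import tensorflow as tf
import tensorflow.contrib.rnn as rnn

cell_size, num_actions = 256, 4
# [batch, time, features] padded observations and their true lengths
obs = tf.placeholder(tf.float32, [None, None, 16], name="obs")
seq_lens = tf.placeholder(tf.int32, [None], name="seq_lens")

# Policy branch: its own LSTM cell plus a linear head for the action logits.
with tf.variable_scope("policy"):
    policy_cell = rnn.BasicLSTMCell(cell_size)
    policy_out, policy_state = tf.nn.dynamic_rnn(
        policy_cell, obs, sequence_length=seq_lens, dtype=tf.float32)
    logits = tf.layers.dense(tf.reshape(policy_out, [-1, cell_size]), num_actions)

# Value branch: a second, independent LSTM cell plus a scalar head.
with tf.variable_scope("value_function"):
    vf_cell = rnn.BasicLSTMCell(cell_size)
    vf_out, vf_state = tf.nn.dynamic_rnn(
        vf_cell, obs, sequence_length=seq_lens, dtype=tf.float32)
    value = tf.reshape(tf.layers.dense(tf.reshape(vf_out, [-1, cell_size]), 1), [-1])

# As in the patch above, the recurrent state to carry between rollout
# fragments is now four tensors: (c, h) for each of the two cells.
state_out = list(policy_state) + list(vf_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {obs: np.zeros((2, 5, 16), np.float32), seq_lens: [5, 3]}
    logits_val, value_val = sess.run([logits, value], feed_dict=feed)
    print(logits_val.shape, value_val.shape)  # (10, 4) (10,)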
Read more comments on GitHub >

