
[rllib] PPO loss affected by LSTM zero padding

See original GitHub issue

When we customize an RNN-based model or simply set use_lstm to True, the sampled data must be zero-padded to max_seq_len, as TFPolicyGraph._get_loss_inputs_dict does.

However, when calculating the final loss, the padded zeros should be excluded. The current PPO loss directly multiplies the padded advantages by logp_ratio: https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ppo/ppo_policy_graph.py#L69

Summation is unaffected (the padded terms are zero), but reduce_mean ends up with an inflated denominator, because the padded zeros are counted as if they were valid timesteps.

Is this an implementation error or a deliberate approximation? If max_seq_len is relatively small, I think the current implementation can be regarded as an approximation made for the sake of efficiency.
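A minimal NumPy sketch of the denominator issue described above, with toy numbers (the variable names are illustrative, not RLlib's):

```python
import numpy as np

# Toy batch: 2 sequences of true lengths 3 and 1, zero-padded to max_seq_len = 3.
seq_lens = np.array([3, 1])
max_seq_len = 3
# Per-timestep surrogate loss terms, flattened to [B * max_seq_len];
# the padded slots are zero.
loss_terms = np.array([1.0, 2.0, 3.0,   # sequence 0 (all valid)
                       4.0, 0.0, 0.0])  # sequence 1 (2 padded slots)

# Naive reduce_mean divides by B * max_seq_len = 6, counting the padded zeros.
naive_mean = loss_terms.mean()  # (1+2+3+4) / 6 ≈ 1.667

# A masked mean divides only by the number of valid timesteps.
mask = (np.arange(max_seq_len)[None, :] < seq_lens[:, None]).reshape(-1)
masked_mean = loss_terms[mask].sum() / mask.sum()  # 10 / 4 = 2.5
```

The gap between the two means grows with the fraction of padding, i.e. with the variance of the sequence lengths relative to max_seq_len.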

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Comments: 14 (14 by maintainers)

Top GitHub Comments

1 reaction
ericl commented, Sep 30, 2018

Yeah, I agree moving the splitting logic outside of TF would make the implementation a lot simpler. The existing implementation is brittle since it does so much work using tf.split() and slice indexing internally, when it could be computed ahead of time in python code.

Btw, I found a different bug in the multi-gpu optimizer for RNNs. The batch is shuffled prior to loading, which breaks up the sequences: https://github.com/ray-project/ray/pull/2996

On the partially observed cartpole example: https://github.com/ray-project/ray/blob/3a3782c39fc2e75f53515660b866e7cc501b7704/python/ray/rllib/examples/cartpole_lstm.py
[Training-curve plot: magenta is before the fix, blue is after. Performance is much better, though the value function seems worse?]

I also tried applying sequence masking before the reduce_mean ops, but it didn’t really seem to affect performance (as expected). Since the padding issue doesn’t seem to have much effect on actual performance, I’m inclined to leave it as is for now due to the multi-gpu optimizer complications.

1 reaction
ericl commented, Sep 29, 2018

Ah, it looks like tf.scatter_nd() is the recommended way to do this.

import tensorflow as tf

def tf_pad_zeros_scatter(inputs, seq_lens):
    # inputs: [sum(seq_lens), ...] valid timesteps concatenated in order.
    # Returns [B * max_seq_len, ...] with zeros in the padded slots.
    padded_size = tf.reduce_max(seq_lens) * tf.size(seq_lens)
    # Flat indices of the valid slots within the padded [B * max_seq_len] grid.
    indices = tf.boolean_mask(
        tf.range(padded_size),
        tf.reshape(tf.sequence_mask(seq_lens), [padded_size]))
    # Scatter the valid rows into a zero-initialized padded buffer.
    return tf.scatter_nd(
        tf.expand_dims(indices, 1), inputs,
        [padded_size] + inputs.get_shape().as_list()[1:])
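For reference, the same scatter logic can be sketched in plain NumPy; this is only an illustration of what the TF function above computes, not RLlib code:

```python
import numpy as np

def np_pad_zeros_scatter(inputs, seq_lens):
    # inputs: [sum(seq_lens), ...] valid rows concatenated in sequence order.
    max_len = seq_lens.max()
    padded_size = max_len * seq_lens.size
    # Boolean mask of valid slots in the flattened [B * max_len] grid.
    mask = (np.arange(max_len)[None, :] < seq_lens[:, None]).reshape(-1)
    indices = np.arange(padded_size)[mask]
    # Scatter the valid rows into a zero-initialized padded buffer.
    out = np.zeros((padded_size,) + inputs.shape[1:], dtype=inputs.dtype)
    out[indices] = inputs
    return out

# Two sequences of lengths 2 and 1, feature dim 1, padded to max_len = 2.
padded = np_pad_zeros_scatter(np.array([[1.], [2.], [3.]]), np.array([2, 1]))
# padded == [[1.], [2.], [3.], [0.]]
```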

Benchmarking the three approaches, scatter_nd is by far the fastest:

  • NumPy: mean latency 0.003065 s
  • TF while loop: mean latency 0.003368 s
  • TF scatter_nd: mean latency 0.000109 s