[rllib] PPO loss affected by LSTM zero padding
When we customize an RNN-based model, or simply set `use_lstm` to `True`, the sampled data has to be zero-padded, as the function `TFPolicyGraph._get_loss_inputs_dict` does.
However, when computing the final loss, the padded zeros need to be masked out. The current PPO loss directly multiplies the padded `advantages` and `logp_ratio`: https://github.com/ray-project/ray/blob/master/python/ray/rllib/agents/ppo/ppo_policy_graph.py#L69
The summation is not a problem, since the padded entries contribute zero, but `reduce_mean` divides by an inflated denominator that includes the padded zeros.
Is this an implementation error or an intentional approximation? If `max_seq_len` is a relatively small value, I think the current implementation can be regarded as an approximation made for the sake of efficiency.
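For illustration, here is a minimal sketch of the masking I have in mind (TF1-style; `masked_mean` is a hypothetical helper, not existing RLlib code, and it assumes a flat `[num_seqs * max_seq_len]` batch plus the `seq_lens` tensor RLlib already feeds):

```python
import tensorflow as tf

def masked_mean(values, seq_lens, max_seq_len):
    """Mean over real timesteps only, ignoring the zero padding."""
    # [num_seqs, max_seq_len] boolean mask: True for valid timesteps.
    mask = tf.sequence_mask(seq_lens, max_seq_len)
    # Flatten to line up with the flat per-timestep loss terms.
    mask = tf.cast(tf.reshape(mask, [-1]), values.dtype)
    # Divide by the number of valid steps, not the padded batch size.
    return tf.reduce_sum(values * mask) / tf.reduce_sum(mask)
```

With something like this, the clipped surrogate terms would be averaged over valid timesteps only.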
Issue Analytics
- Created: 5 years ago
- Comments: 14 (14 by maintainers)
Top GitHub Comments
Yeah, I agree moving the splitting logic outside of TF would make the implementation a lot simpler. The existing implementation is brittle, since it does so much work with `tf.split()` and slice indexing inside the graph, when it could be computed ahead of time in Python code.
Btw, I found a different bug in the multi-gpu optimizer for RNNs. The batch is shuffled prior to loading, which breaks up the sequences: https://github.com/ray-project/ray/pull/2996
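To sketch what that fix implies (hypothetical helper and shapes, not the actual code in the PR): shuffling should permute whole padded sequences rather than individual timesteps, e.g.:

```python
import numpy as np

def shuffle_sequences(batch, seq_lens, max_seq_len):
    """Permute whole padded sequences instead of individual timesteps."""
    num_seqs = len(seq_lens)
    perm = np.random.permutation(num_seqs)
    # View the flat [num_seqs * max_seq_len, ...] batch as whole chunks,
    # permute along the sequence axis, then flatten back.
    chunks = batch.reshape(num_seqs, max_seq_len, *batch.shape[1:])
    return chunks[perm].reshape(batch.shape), np.asarray(seq_lens)[perm]
```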
On the partially observed cartpole example: https://github.com/ray-project/ray/blob/3a3782c39fc2e75f53515660b866e7cc501b7704/python/ray/rllib/examples/cartpole_lstm.py
[Plot: magenta is before the fix, blue is after the fix; performance is much better, though the value function seems worse?]
I also tried applying sequence masking before the `reduce_mean` ops, but it didn't really seem to affect performance (as expected). Since the padding issue doesn't seem to have much effect on actual performance, I'm inclined to leave it as is for now, given the multi-GPU optimizer complications.
Ah, it looks like `tf.scatter_nd()` is the recommended way to do this. It also seems to be the fastest.
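A toy illustration of the mechanics (hard-coded indices and shapes for clarity, not the code that ended up in RLlib):

```python
import tensorflow as tf

# Three variable-length sequences (lengths 3, 1, 2) scattered into a
# zero-padded [3, 4] tensor; positions not listed in `indices` stay
# zero, which is exactly the padding.
indices = tf.constant([[0, 0], [0, 1], [0, 2],   # sequence 0
                       [1, 0],                   # sequence 1
                       [2, 0], [2, 1]])          # sequence 2
values = tf.constant([1., 2., 3., 4., 5., 6.])
padded = tf.scatter_nd(indices, values, shape=[3, 4])
```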