[Feature Request] Better KL Divergence Approximation
🚀 Feature
Use a better estimate for the KL Divergence in the PPO algorithm.
The estimator I propose is the one described in John Schulman's note on approximating KL divergence (linked in the comments below): `mean((r - 1) - log(r))`, where `r = pi_new(a|s) / pi_old(a|s)` is the probability ratio.
This has lower variance, since the extra `(r - 1)` term is negatively correlated with `-log(r)`, and every per-sample estimate is non-negative (see the motivation section).
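For reference, here is a short derivation of those two properties, assuming samples are drawn under the old policy and writing `r = pi_new(a|s) / pi_old(a|s)`:

```
E[r] = E_{a ~ pi_old}[ pi_new(a|s) / pi_old(a|s) ] = 1
=> E[(r - 1) - log r] = E[-log r] = KL(pi_old || pi_new)   (still unbiased)
log r <= r - 1  for all r > 0
=> (r - 1) - log r >= 0  for every sample                  (never negative)
```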
Motivation
In the PPO algorithm, the approximate KL divergence is used as an early-stopping criterion to block overly large policy updates within a single update phase. The line currently used is `approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())`. This is an unbiased estimator, but it has high variance, because individual estimates can be negative (unlike the true KL divergence, which is always non-negative)! This can cause problems in the check `if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:`, since the mean does not account for the large negative values that `approx_kl_divs` can take. The short demo below illustrates the difference.
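Here is a hypothetical demo (not stable-baselines3 code, just a sketch): for two nearby Gaussians standing in for the old and new policies, the current estimator (`k1` below) produces negative per-sample values and a larger spread, while the proposed estimator (`k3`) stays non-negative with a smaller spread.

```python
import torch as th

# Hypothetical demo: estimate KL(pi_old || pi_new) for two Gaussians
# from samples drawn under pi_old, comparing both estimators.
pi_old = th.distributions.Normal(0.0, 1.0)
pi_new = th.distributions.Normal(0.1, 1.0)
x = pi_old.sample((100_000,))

log_ratio = pi_new.log_prob(x) - pi_old.log_prob(x)  # plays the role of log_prob - old_log_prob
k1 = -log_ratio                                      # current estimator, per sample
k3 = th.exp(log_ratio) - 1 - log_ratio               # proposed estimator, per sample

true_kl = th.distributions.kl_divergence(pi_old, pi_new)
print(f"true KL : {true_kl.item():.5f}")
print(f"k1      : mean {k1.mean().item():.5f}, std {k1.std().item():.5f}, min {k1.min().item():.3f}  (can be negative)")
print(f"k3      : mean {k3.mean().item():.5f}, std {k3.std().item():.5f}, min {k3.min().item():.3f}  (always >= 0)")
```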
Pitch
I want to replace the KL Divergence equation currently used in the PPO algorithm with the better approximation described above.
NOTE: This is basically a one-line change (see the sketch below).
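Concretely, the change I have in mind looks roughly like this (a sketch against the training loop quoted in the motivation section; `th`, `log_prob`, `rollout_data` and `approx_kl_divs` are the names already used there, and the exact line may differ between versions):

```python
# Current:
approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())

# Proposed (sketch):
log_ratio = log_prob - rollout_data.old_log_prob
approx_kl_divs.append(th.mean(th.exp(log_ratio) - 1 - log_ratio).detach().cpu().numpy())
```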
Alternatives
N/A
Additional context
### Checklist
- I have checked that there is no similar issue in the repo (required)
Top GitHub Comments
See this article for some good stuff on forward vs reverse KL: https://dibyaghosh.com/blog/probability/kldivergence.html
Sure! Here’s a blog post on the topic by John Schulman: http://joschu.net/blog/kl-approx.html