
[Feature Request] Recompute the advantage of a minibatch in ppo


🚀 Feature

According to this paper, recomputing the advantage can improve PPO performance. This feature is provided by the tianshou library:

https://github.com/thu-ml/tianshou/blob/655d5fb14fe85ea9da86b441456286fa1f078384/tianshou/policy/modelfree/ppo.py#L107

But I don’t know how to add this to SB3. Some hints on how to do that would be very helpful.

Thanks!

Motivation

I am comparing stable-baselines3, tianshou, and rllib to see which gives the best PPO performance.

Pitch

Recompute the advantage during PPO training.
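
For context, here is a minimal, framework-agnostic sketch of the quantity to be recomputed. “Recomputing the advantage” means re-running generalized advantage estimation (GAE) over the stored rollout with fresh value estimates from the current critic; the function and argument names below are illustrative only, not SB3 or tianshou APIs.

# Illustrative GAE over a single-env rollout of length T.
# rewards[t] is the reward after step t, values[t] = V(s_t) from the current critic,
# dones[t] marks whether the episode ended after step t, and last_value = V(s_T)
# is used to bootstrap the final step.
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_non_terminal = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * gae_lambda * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + np.asarray(values, dtype=np.float32)
    return advantages, returns

Recomputing per epoch then just means refreshing values (and last_value) with the current value network and running this again before each pass over the buffer, instead of keeping the advantages computed once at collection time.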

Checklist

  • I have checked that there is no similar issue in the repo (required)

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 16 (7 by maintainers)

Top GitHub Comments

1 reaction
yangysc commented, May 24, 2021

Hello, I recompute it after each epoch, consistent with the tianshou library: https://github.com/thu-ml/tianshou/blob/655d5fb14fe85ea9da86b441456286fa1f078384/tianshou/policy/modelfree/ppo.py#L107

I pasted the main modification below. Hopefully you can help check whether there are any potential problems.

# ppo.py
# (torch as th and gym's `spaces` are already imported in SB3's ppo.py;
#  the recompute step below additionally needs:
#  from stable_baselines3.common.utils import obs_as_tensor)
  def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.
        """
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
 
            # Recompute advantages with the current critic before each pass over the buffer.
            # Note: compute_returns_and_advantage() reuses the value estimates stored at
            # collection time; only the bootstrap value for the last observation is
            # re-evaluated here. For epoch 0 this repeats what collect_rollouts already did.
            with th.no_grad():
                _, last_values, _ = self.policy(obs_as_tensor(self._last_obs, self.device))
                self.rollout_buffer.compute_returns_and_advantage(last_values, dones=self._last_episode_starts)

            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                # TODO: investigate why there is no issue with the gradient
                # if that line is commented (as in SAC)
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                values = values.flatten()

           
                # Normalize advantage
                advantages = rollout_data.advantages
                # ... (rest of the minibatch loop unchanged from the upstream PPO.train)
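
As a side note: tianshou’s recompute path re-evaluates the critic on all stored states each epoch, while compute_returns_and_advantage() above reuses the value estimates saved at collection time and only refreshes the bootstrap value for the last observation. If one wanted to match tianshou more closely, the stored values could be refreshed first. The helper below is only a sketch against SB3’s RolloutBuffer layout (observations stored as a (n_steps, n_envs, *obs_shape) numpy array); refresh_buffer_values is a made-up name and the code is untested.

# Hypothetical helper: re-evaluate V(s) for every stored transition with the current critic.
import torch as th
from stable_baselines3.common.utils import obs_as_tensor

def refresh_buffer_values(policy, rollout_buffer, device):
    n_steps, n_envs = rollout_buffer.buffer_size, rollout_buffer.n_envs
    flat_obs = rollout_buffer.observations.reshape(n_steps * n_envs, *rollout_buffer.obs_shape)
    with th.no_grad():
        # ActorCriticPolicy.forward returns (actions, values, log_probs)
        _, values, _ = policy(obs_as_tensor(flat_obs, device))
    rollout_buffer.values = values.cpu().numpy().reshape(n_steps, n_envs)

Calling something like this just before compute_returns_and_advantage() in the epoch loop would make the recomputation use current value estimates everywhere, not only at the rollout boundary (with the caveat that the old_values used for value-function clipping would change as well).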

To support multiple envs, I did what you suggested before: avoid overwriting the variables (https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/buffers.py#L441) and reshape them whenever sampling. We don’t have to do this if we only use one env, but reshaping at every sample heavily slows down the learning process… Do you have a good solution for this? (A possible workaround is sketched after the code below.)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        # Reshape lazily at sampling time so the (n_steps, n_envs) buffers stay intact
        # for the per-epoch advantage recomputation; flatten first, then index, so that
        # batch_inds addresses the flattened (n_steps * n_envs) axis consistently.
        data = (
            self.swap_and_flatten(self.observations)[batch_inds],
            self.swap_and_flatten(self.actions)[batch_inds],
            self.swap_and_flatten(self.values)[batch_inds].flatten(),
            self.swap_and_flatten(self.log_probs)[batch_inds].flatten(),
            self.swap_and_flatten(self.advantages)[batch_inds].flatten(),
            self.swap_and_flatten(self.returns)[batch_inds].flatten(),
            # Left over from an attempt to also return rewards / episode starts / indices:
            # self.swap_and_flatten(self.rewards)[batch_inds].flatten(),
            # self.swap_and_flatten(self.episode_starts)[batch_inds].flatten(),
            # batch_inds
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
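
One way to cut that cost (an untested sketch, not an SB3 feature; the class name is made up): keep the buffer arrays in their (n_steps, n_envs) layout so that compute_returns_and_advantage() can be re-run each epoch, but build the flattened copies once per get() call, i.e. once per epoch, instead of once per minibatch.

import numpy as np
from typing import Generator, Optional

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.type_aliases import RolloutBufferSamples
from stable_baselines3.common.vec_env import VecNormalize


class NonDestructiveRolloutBuffer(RolloutBuffer):
    """Rollout buffer that flattens into temporary copies instead of overwriting itself."""

    def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
        assert self.full, "Rollout buffer must be full before sampling"
        # Flatten once per epoch; the (n_steps, n_envs) originals stay untouched,
        # so advantages/returns can be recomputed in place between epochs.
        self._flat = {
            key: self.swap_and_flatten(getattr(self, key))
            for key in ("observations", "actions", "values", "log_probs", "advantages", "returns")
        }
        total = self.buffer_size * self.n_envs
        if batch_size is None:
            batch_size = total
        indices = np.random.permutation(total)
        start = 0
        while start < total:
            yield self._get_samples(indices[start : start + batch_size])
            start += batch_size

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
        data = (
            self._flat["observations"][batch_inds],
            self._flat["actions"][batch_inds],
            self._flat["values"][batch_inds].flatten(),
            self._flat["log_probs"][batch_inds].flatten(),
            self._flat["advantages"][batch_inds].flatten(),
            self._flat["returns"][batch_inds].flatten(),
        )
        return RolloutBufferSamples(*tuple(map(self.to_torch, data)))

This trades one full copy of the buffer per epoch for the per-minibatch reshaping above; whether it is actually faster will depend on the observation size and batch_size.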

I also tried recomputing it after sampling batch_size observations, i.e. at each gradient step, yesterday. I returned the sample indices along with the sampled data, but unfortunately I didn’t manage to get it working :<
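
For what it’s worth, a rough sketch of what that per-minibatch variant could look like, assuming _get_samples were changed to also hand back its batch_inds (as in the commented-out line above). This is illustrative only, untested, and has the same reshaping cost concern as before:

# Inside the minibatch loop, once rollout_data and its batch_inds are available:
with th.no_grad():
    _, last_values, _ = self.policy(obs_as_tensor(self._last_obs, self.device))
    self.rollout_buffer.compute_returns_and_advantage(last_values, dones=self._last_episode_starts)
# Advantages are stored as (n_steps, n_envs); flatten them the same way the sampler
# does so the flat batch_inds line up, then pick out this minibatch.
flat_advantages = self.rollout_buffer.swap_and_flatten(self.rollout_buffer.advantages).flatten()
advantages = th.as_tensor(flat_advantages[batch_inds], device=self.device)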

1 reaction
yangysc commented, May 24, 2021

solved…😃 Thanks
