[Feature Request] Refactor `predict` method in `BasePolicy` class
### Feature
At present, the `predict` method in the `BasePolicy` class contains quite a lot of logic that could be reused to provide similar functionality. In particular, the current logic of this method is as follows:
1. Pre-process the NumPy observation and convert it into a PyTorch Tensor.
2. Generate action(s) from the child policy class through the `_predict` method, with these actions in the form of a PyTorch Tensor.
3. Post-process the actions, including converting the PyTorch Tensor into a NumPy array.
My suggestion is that steps (1) and (3) be refactored into individual functions on the `BasePolicy` class, which are then called in the `predict` method.
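To make the proposal concrete, here is a minimal sketch of the split; the helper names `_preprocess_obs` and `_postprocess_actions` are made up for illustration and are not existing stable-baselines3 API:

```python
import numpy as np
import torch as th


class BasePolicy:
    """Sketch only: illustrates the proposed split, not the real SB3 class."""

    def _preprocess_obs(self, observation: np.ndarray) -> th.Tensor:
        # Step (1): convert the NumPy observation into a (batched) tensor.
        return th.as_tensor(observation, dtype=th.float32)

    def _postprocess_actions(self, actions: th.Tensor) -> np.ndarray:
        # Step (3): move back to NumPy (clipping/unscaling omitted here).
        return actions.detach().cpu().numpy()

    def _predict(self, obs_tensor: th.Tensor) -> th.Tensor:
        raise NotImplementedError  # step (2), provided by child policy classes

    def predict(self, observation: np.ndarray) -> np.ndarray:
        obs_tensor = self._preprocess_obs(observation)  # step (1)
        actions = self._predict(obs_tensor)             # step (2)
        return self._postprocess_actions(actions)       # step (3)
```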
### Motivation
I would like to introduce some policy classes for which I can calculate the action probabilities and not the actions themselves. (This is for some work on off-policy estimation that I am doing.)
Let's call this functionality `predict_probabilities`; at present, the initial logic of this functionality is identical to step (1) of the `predict` method. If the code is refactored as suggested, then both approaches can use the same pre-processing functionality, as sketched below.
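Building on the hypothetical helpers above (and reusing their imports), `predict_probabilities` could then share step (1) directly; `_predict_probabilities` is likewise a made-up child-class hook:

```python
class OffPolicyEstimationPolicy(BasePolicy):
    """Sketch only: a policy exposing action probabilities instead of actions."""

    def _predict_probabilities(self, obs_tensor: th.Tensor) -> th.Tensor:
        raise NotImplementedError  # child-class specific probability logic

    def predict_probabilities(self, observation: np.ndarray) -> np.ndarray:
        obs_tensor = self._preprocess_obs(observation)   # same step (1) as predict
        probs = self._predict_probabilities(obs_tensor)
        return probs.detach().cpu().numpy()              # same idea as step (3)
```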
Additionally, I think the refactor would generally make the code more readable and make it easier to extend parts of the functionality to other similar uses.
### Pitch
I am happy to do a PR for the proposed refactor, so I would like to know whether you would accept the proposal.
### Alternatives
None
### Additional context
None
### Checklist
- I have checked that there is no similar issue in the repo (required)
### Top GitHub Comments
Try putting your code ``` like this ```, that should look nice.
Kind of. Your code is answering the question "what is the log-probability of the action it chose". You need to inspect the `distribution` variable if you want to know the probability of picking any one of the actions. The exact code depends on your action space, but for a Discrete space this would be `distribution.distribution.probs` (the `distribution.distribution` object is a PyTorch distribution object).

Nope, maintainers are doing this in their free time and partially for their work. The best way to contribute back is by giving comments, spotting errors and, best of all, doing PRs to update things!
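To make the `distribution.distribution.probs` suggestion concrete, here is a short sketch; it assumes a recent stable-baselines3 version where the policy exposes `obs_to_tensor` and `get_distribution`, and an environment with a Discrete action space:

```python
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")  # assumption: any Discrete-action env
vec_env = model.get_env()
obs = vec_env.reset()

# obs_to_tensor performs the pre-processing discussed in this issue;
# get_distribution returns the SB3 distribution wrapper.
obs_tensor, _ = model.policy.obs_to_tensor(obs)
dist = model.policy.get_distribution(obs_tensor)

# For a Discrete space the wrapped object is a torch Categorical,
# so .probs holds the probability of every action in the current state.
print(dist.distribution.probs)
```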
@ziegenbalg Actually you can use a trick for this with PPO. As @Miffyli said above, you can use the `evaluate_actions` method of the policy object. This example worked for me (I think; maybe @Miffyli sees some error). Note that I'm trying to get the probability of performing an action (my action space is Binary in this case), therefore I use the `np.ones_like` function.
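A sketch of what such usage might look like, assuming a PPO model on a binary/Discrete action space and using `np.ones_like` to build the "always act" action:

```python
import numpy as np
import torch as th
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")  # assumption: any env with a binary action space
vec_env = model.get_env()
obs = vec_env.reset()

# Build batched tensors; np.ones_like on a sampled action produces the
# "always pick action 1" action whose probability we are after.
obs_tensor, _ = model.policy.obs_to_tensor(obs)
actions = th.as_tensor(
    np.ones_like([vec_env.action_space.sample()]), device=model.policy.device
)

# For ActorCriticPolicy, evaluate_actions returns (values, log_prob, entropy).
values, log_prob, entropy = model.policy.evaluate_actions(obs_tensor, actions)
print(log_prob.exp())  # probability of taking action 1 in the current state
```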