[Feature Request] Refactor `predict` method in `BasePolicy` class


🚀 Feature

A clear and concise description of the feature proposal.

At present, the predict method in the BasePolicy class contains quite a lot of logic that could be reused to provide similar functionality. In particular, the current logic of this method is as follows:

  1. Pre-process NumPy observation and convert it into a PyTorch Tensor.
  2. Generate action(s) from the child policy class through the _predict method, with these actions in the form of a PyTorch Tensor.
  3. Post-process the actions, including converting the PyTorch Tensor into a NumPy array.

My suggestion is that steps (1) and (3) are refactored into individual functions on the BasePolicy class, which are then called in the predict method.
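As a rough illustration of the proposed split, here is a minimal sketch. The class `SketchBasePolicy`, the helper names `obs_to_tensor` and `actions_to_numpy`, and the toy `ArgmaxPolicy` are all hypothetical and are not taken from the library's code; the real pre- and post-processing (for example action clipping or image handling) is likely more involved.

```python
import numpy as np
import torch as th


class SketchBasePolicy:
    """Hypothetical stand-in for BasePolicy; not the library's actual code."""

    def obs_to_tensor(self, observation: np.ndarray) -> th.Tensor:
        # Step 1: add a batch dimension if needed and convert to a float tensor.
        observation = np.asarray(observation, dtype=np.float32)
        if observation.ndim == 1:
            observation = observation[None, :]
        return th.as_tensor(observation)

    def _predict(self, obs: th.Tensor) -> th.Tensor:
        # Step 2: implemented by each child policy.
        raise NotImplementedError

    def actions_to_numpy(self, actions: th.Tensor) -> np.ndarray:
        # Step 3: detach from the autograd graph and convert back to NumPy.
        return actions.detach().cpu().numpy()

    def predict(self, observation: np.ndarray) -> np.ndarray:
        # predict is now just the three steps chained together.
        obs = self.obs_to_tensor(observation)
        with th.no_grad():
            actions = self._predict(obs)
        return self.actions_to_numpy(actions)


class ArgmaxPolicy(SketchBasePolicy):
    """Toy child policy: picks the index of the largest observation entry."""

    def _predict(self, obs: th.Tensor) -> th.Tensor:
        return obs.argmax(dim=1)


policy = ArgmaxPolicy()
actions = policy.predict(np.array([0.1, 0.9, 0.3]))
```

With this layout, any other method on the base class could call the step (1) and step (3) helpers without duplicating them.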

Motivation

I would like to introduce some policy classes for which I can calculate the action probabilities and not the actions themselves. (This is for some work on off-policy estimation that I am doing.)

Let’s call this functionality predict_probabilities. At present, its initial logic is identical to step (1) of the predict method. If the code is refactored as suggested, both methods can share the same pre-processing functionality.
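To make the sharing concrete, the following sketch shows a predict method and a predict_probabilities method that both go through one pre-processing function. Everything here is an assumption for illustration: `preprocess_obs`, `ToyDiscretePolicy`, and the linear logit layer are invented names, and a softmax over logits is just one plausible way to obtain action probabilities for a discrete action space.

```python
import numpy as np
import torch as th


def preprocess_obs(observation: np.ndarray) -> th.Tensor:
    # Shared step (1): NumPy observation -> batched float32 tensor.
    obs = np.asarray(observation, dtype=np.float32)
    return th.as_tensor(obs.reshape(1, -1) if obs.ndim == 1 else obs)


class ToyDiscretePolicy:
    """Hypothetical policy over 2 discrete actions with a linear logit layer."""

    def __init__(self) -> None:
        self.net = th.nn.Linear(3, 2)

    def predict(self, observation: np.ndarray) -> np.ndarray:
        # Returns the greedy action for each observation in the batch.
        with th.no_grad():
            logits = self.net(preprocess_obs(observation))
        return logits.argmax(dim=1).numpy()

    def predict_probabilities(self, observation: np.ndarray) -> np.ndarray:
        # Returns the probability of every action instead of a single action.
        with th.no_grad():
            logits = self.net(preprocess_obs(observation))
        return th.softmax(logits, dim=1).numpy()


policy = ToyDiscretePolicy()
obs = np.array([0.5, -1.0, 2.0])
probs = policy.predict_probabilities(obs)
action = policy.predict(obs)
```

Here `probs` has one row per observation and one column per action, which is the shape an off-policy estimator would typically consume.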

Additionally, I think the refactor would make the code more readable and make it easier to reuse parts of the functionality for other, similar purposes.

Pitch

I am happy to do a PR for the proposed refactor, so I would like to know whether or not you would be happy with the proposal.

Alternatives

None

Additional context

None

### Checklist

  • I have checked that there is no similar issue in the repo (required)

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 20 (5 by maintainers)

Top GitHub Comments

2 reactions
Miffyli commented, Sep 5, 2021

(Sorry, the code formatting was messing up, so removed it…)

Try putting your code ``` like this ```, that should look nice 😃

It seems to work, but it’s only giving me one probability, which is fine since I only have two actions so I subtract from 1 to get the other. Does that look right?

Kind of. Your code is answering the question “what is the log-probability of the action it chose”. You need to inspect the distribution variable if you want to know the probability of picking any one of the actions. The exact code depends on your action space, but for a Discrete space this would be distribution.distribution.probs (the distribution.distribution object is a PyTorch distribution object).

Btw, do you guys have a Patreon page? Stable-baselines is such an excellent project! It’s taught me a lot about coding/machine learning and it’s so straightforward. Love it!

Nope, maintainers are doing this in their free time and partially for their work 😃. Best way to contribute back is by giving comments, spotting errors and best of all: doing PRs to update things!
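The difference between the two quantities can be seen with a plain PyTorch distribution. This is only a sketch with made-up logits and chosen actions; it uses `torch.distributions.Categorical` directly rather than any policy object, on the assumption (stated above) that the inner distribution for a Discrete space behaves like a PyTorch categorical distribution.

```python
import torch as th
from torch.distributions import Categorical

# Made-up logits for a batch of 2 states and 2 discrete actions.
logits = th.tensor([[2.0, 0.0], [0.0, 1.0]])
dist = Categorical(logits=logits)

# Actions the policy is assumed to have chosen in each state.
chosen = th.tensor([0, 1])

# Probability of only the chosen action, recovered from its log-probability.
chosen_probs = dist.log_prob(chosen).exp()

# Full distribution over every action, one row per state.
all_probs = dist.probs
```

With two actions, subtracting `chosen_probs` from 1 gives the probability of the other action, but `all_probs` generalizes to any number of actions.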

2 reactions
EloyAnguiano commented, Aug 15, 2021

Please please make this happen, it would be so nice to see how ‘certain’ the neural net is of its action on each step. I’m working on a wind/tides program for myself to classify environmental conditions into ‘go kiting’ or ‘don’t go kiting’ 😃 I know, not how the gym env is supposed to be used, but stable baselines just makes it so easy to code even I can do it. I don’t understand the mathematics underneath it well enough to make a PR for you guys. I tried hacking around some print statements to no avail 😦 (maybe someone has a quick and dirty one for PPO).

@ziegenbalg Actually you can make a trick for this with PPO. As @Miffyli said above, you can use the evaluate_actions method for the policy object. This example worked for me (I think, maybe @Miffyli sees some error):

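import numpy as np
import torch as th

# Convert lists to tensors
states_tensor = th.from_numpy(np.asmatrix(states))
# All-ones actions: ask for the probability of performing the action
actions_tensor = th.from_numpy(np.ones_like(np.asmatrix(actions)))

values, log_prob, entropy = best_model.policy.evaluate_actions(states_tensor, actions_tensor)

# exp(log-probability) gives the probability of the evaluated action
probs = np.exp(log_prob.detach().numpy())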

Note that I’m trying to get the probability of performing an action (my action space is binary in this case), therefore I use the np.ones_like function.
