[rllib] Add an option to define whether a model outputs logits or probs
I propose adding an option to define what the inputs to the default action distributions represent (logits or probs). As I understand it, there is an assumption that models always return logits, but this is problematic when trying to suppress specific actions by assigning them zero probability. A common practice (e.g. in Attention Is All You Need, OpenAI Five) is to set the logits of those actions to -inf and then apply softmax. This, however, means the model outputs probs. If we now pass these probs into an action distribution such as TorchMultiCategorical, they are normalized once again as if they were logits, which breaks the masking: since exp(0) = 1, an action with probability 0 ends up with a nonzero probability. If, on the other hand, we output logits containing -inf values, computing the entropy produces NaNs.
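Both failure modes can be reproduced without PyTorch. The following is a minimal pure-Python sketch of the arithmetic; the helpers `softmax` and `entropy_from_logits` are illustrative names, not RLlib or PyTorch APIs, and the naive entropy formula is an assumption about how the NaN arises (library implementations may guard against it differently).

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats; -inf entries map to 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_from_logits(logits):
    """Naive entropy -sum(p * log p), with log p derived from the logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    probs = [math.exp(lp) for lp in log_probs]
    # For a masked action p == 0.0 and log p == -inf, so 0.0 * -inf == nan.
    return -sum(p * lp for p, lp in zip(probs, log_probs))


# Three actions; action 2 is masked out by setting its logit to -inf.
masked_logits = [1.0, 2.0, float("-inf")]
probs = softmax(masked_logits)  # action 2 gets probability 0.0

# Treating these probs as logits normalizes them a second time:
# exp(0.0) == 1.0, so the masked action becomes selectable again.
double_normalized = softmax(probs)

# Feeding the -inf logits directly instead breaks the entropy.
entropy = entropy_from_logits(masked_logits)
```

Here `double_normalized[2]` is strictly positive and `entropy` is NaN, which are the two problems described above.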
I know there is an option to define a custom action distribution, and I did that, but the custom class ends up differing from the built-in one in only a single line, e.g.:
```python
import torch

from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
from ray.rllib.utils.annotations import override


class TorchMultiCategorical(TorchDistributionWrapper):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(TorchDistributionWrapper)
    def __init__(self, inputs, model, input_lens):
        super().__init__(inputs, model)
        # If input_lens is np.ndarray or list, force-make it a tuple.
        inputs_split = self.inputs.split(tuple(input_lens), dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)  # <===== HERE (logits= vs. probs=)
            for input_ in inputs_split
        ]
```
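To illustrate what the one-line change does, here is a framework-free sketch of a multi-categorical distribution that interprets its inputs as probabilities. `ProbsMultiCategoricalSketch` and `split` are hypothetical names; the per-chunk renormalization by the sum mirrors what `torch.distributions.Categorical(probs=...)` does, and the `0 * log(0) = 0` convention in the entropy is an assumption of this sketch.

```python
import math


def split(inputs, input_lens):
    """Split a flat list into consecutive chunks of the given lengths
    (mirrors `self.inputs.split(tuple(input_lens), dim=1)` for a single row)."""
    chunks, start = [], 0
    for n in input_lens:
        chunks.append(inputs[start:start + n])
        start += n
    return chunks


class ProbsMultiCategoricalSketch:
    """Hypothetical stand-in for a MultiCategorical whose inputs are probs."""

    def __init__(self, inputs, input_lens):
        # Each chunk is only renormalized by its sum (no softmax),
        # so zeros produced by an earlier masked softmax stay exactly zero.
        self.cats = []
        for chunk in split(inputs, input_lens):
            total = sum(chunk)
            self.cats.append([p / total for p in chunk])

    def entropy(self):
        # Sum of per-sub-action entropies, skipping zero-probability terms
        # (the 0 * log(0) = 0 convention), so masked actions yield no NaN.
        return sum(
            -sum(p * math.log(p) for p in cat if p > 0.0) for cat in self.cats
        )


# Two sub-actions with 3 and 2 choices; the last choice of the first is masked.
dist = ProbsMultiCategoricalSketch([0.25, 0.75, 0.0, 0.5, 0.5], input_lens=[3, 2])
```

With probs as inputs, the masked choice keeps probability 0.0 and `dist.entropy()` stays finite.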
My idea is to have a property, e.g. in the model, which is set either from the config or explicitly in the model class and tells whether this model outputs logits or probs.
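A minimal framework-free sketch of the proposed property, assuming a hypothetical class attribute `outputs_probs` on the model (not an existing RLlib option) that falls back to the current logits behavior when the attribute is absent. `MaskedPolicyModel` and `to_probs` are likewise illustrative names.

```python
import math


class MaskedPolicyModel:
    """Hypothetical model that applies softmax to masked logits itself."""

    # Hypothetical flag telling the action distribution how to read outputs.
    outputs_probs = True


def to_probs(inputs, outputs_probs):
    """Turn raw model outputs into a probability vector according to the flag."""
    if outputs_probs:
        # Already probabilities: only renormalize, so zeros stay zero.
        total = sum(inputs)
        return [p / total for p in inputs]
    # Logits (the current assumption): apply a numerically stable softmax.
    m = max(inputs)
    exps = [math.exp(x - m) for x in inputs]
    total = sum(exps)
    return [e / total for e in exps]


model = MaskedPolicyModel()
model_output = [0.2, 0.8, 0.0]  # third action masked by the model
# Default to False (logits) for models that do not define the flag.
probs = to_probs(model_output, getattr(model, "outputs_probs", False))
```

Defaulting to logits keeps existing models unaffected; only models that opt in via the flag get their outputs treated as probs.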
I am happy to prepare a PR for this (at least for the PyTorch classes), but I’d like to hear your opinion first. 😃
- Created 3 years ago
- Comments: 5 (3 by maintainers)
Top GitHub Comments
Feel free to PR this. We’ll help you get this merged.
Please reopen, as a PR for this has been opened.