[rllib] Add an option to define whether a model outputs logits or probs

I propose adding an option that defines what a model's outputs represent (logits or probs) when using the default action distributions. As I understand it, RLlib assumes that models always return logits, which is problematic when you want to suppress specific actions by assigning them zero probability. A common practice (e.g. in Attention Is All You Need and OpenAI Five) is to set the logits of those actions to -inf and then apply softmax. The model then outputs probs rather than logits. If we pass these probs to e.g. the TorchMultiCategorical action distribution, they are treated as logits and normalized a second time, so the suppressed actions end up with a non-zero probability again. If we instead output the logits with -inf values directly, computing the entropy produces NaNs.
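The two failure modes described above can be reproduced with plain Python arithmetic. This is only a sketch of the underlying math, not RLlib or PyTorch code: the helper names (`softmax`, `log_softmax`, `naive_entropy`) are made up for illustration, and whether a particular PyTorch version actually returns NaN depends on how it computes the entropy internally.

```python
import math

NEG_INF = float("-inf")


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    # log(softmax(x)) computed via the log-sum-exp trick.
    m = max(logits)
    log_total = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_total for x in logits]


def naive_entropy(logits):
    # Entropy as -sum(p * log p), with no special handling of p == 0.
    probs = softmax(logits)
    log_probs = log_softmax(logits)
    return -sum(p * lp for p, lp in zip(probs, log_probs))


# The middle action is suppressed by setting its logit to -inf.
masked_logits = [1.0, NEG_INF, 1.0]

# Softmax gives the suppressed action exactly zero probability.
probs = softmax(masked_logits)

# Feeding these probs back in as if they were logits normalizes them again
# and gives the suppressed action a positive probability.
renormalized = softmax(probs)

# With -inf logits, the suppressed action contributes 0 * -inf = nan.
entropy = naive_entropy(masked_logits)

print(probs, renormalized, entropy)
```

The second softmax is what effectively happens when probs are passed to a distribution that expects logits, and the `0 * -inf` term is the source of the NaN in the entropy.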

I know there is an option to define a custom action distribution, and I did that, but it amounts to changing a single line in the existing action distribution, e.g.:

class TorchMultiCategorical(TorchDistributionWrapper):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(TorchDistributionWrapper)
    def __init__(self, inputs, model, input_lens):
        super().__init__(inputs, model)
        # If input_lens is np.ndarray or list, force-make it a tuple.
        inputs_split = self.inputs.split(tuple(input_lens), dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)  # <===== HERE
            for input_ in inputs_split
        ]

My idea is to have a property, e.g. on the model, that is set from the config or explicitly in the model class and tells whether the model outputs logits or probs.
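As a rough sketch of that idea, the choice between `logits=` and `probs=` could be driven by a single flag. The function `build_categoricals` and the flag `outputs_probs` below are hypothetical names used only for illustration. They are not part of RLlib, and this is not the code from the eventual PR. Only `torch.distributions.Categorical` and its `logits`/`probs` arguments are real PyTorch API.

```python
import torch
from torch.distributions import Categorical


def build_categoricals(inputs, input_lens, outputs_probs=False):
    """Split flat model outputs into one Categorical per sub-action.

    `outputs_probs` is a hypothetical flag (not an RLlib option) stating
    whether `inputs` holds probabilities instead of logits.
    """
    inputs_split = inputs.split(tuple(input_lens), dim=1)
    if outputs_probs:
        return [Categorical(probs=chunk) for chunk in inputs_split]
    return [Categorical(logits=chunk) for chunk in inputs_split]


# One batch row, two sub-actions with 3 and 2 choices.
# Action 1 of the first sub-action has been suppressed (probability 0).
model_out = torch.tensor([[0.5, 0.0, 0.5, 0.25, 0.75]])

# Interpreted as probs, the suppressed action keeps probability 0.
as_probs = build_categoricals(model_out, [3, 2], outputs_probs=True)

# Interpreted as logits, the same values are pushed through a softmax
# and the suppressed action becomes possible again.
as_logits = build_categoricals(model_out, [3, 2], outputs_probs=False)

print(as_probs[0].probs, as_logits[0].probs)
```

In an actual action distribution class, the flag would presumably be read from the model or the config rather than passed as an argument.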

I am happy to prepare a PR for this (at least for the PyTorch classes), but I’d like to hear your opinion first. 😃

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
sven1977 commented, May 21, 2020

Feel free to PR this. We’ll help you get this merged.

0 reactions
iamhatesz commented, Nov 30, 2020

Please reopen, as a PR for this has been opened.
