
[rllib] Trainer.compute_action Error with Dict type observation inputs


What is the problem?

Python 3.8.5; TensorFlow: tensorflow-gpu 2.0.0; Ray: 1.0.1 and 0.8.6

I want to reproduce the code from the blog post Action Masking with RLlib, but I get an error.

Training Script

Here is the script (install or_gym first with pip install or_gym):

from or_gym.utils import create_env
from gym import spaces
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models import ModelCatalog
from ray import tune
from ray.rllib import agents
import ray
import or_gym
import numpy as np
env = or_gym.make('Knapsack-v0')

print("Max weight capacity:\t{}kg".format(env.max_weight))
print("Number of items:\t{}".format(env.N))

env_config = {'N': 5,
              'max_weight': 15,
              'item_weights': np.array([1, 12, 2, 1, 4]),
              'item_values': np.array([2, 4, 2, 1, 10]),
              'mask': True}
env = or_gym.make('Knapsack-v0', env_config=env_config)
print("Max weight capacity:\t{}kg".format(env.max_weight))
print("Number of items:\t{}".format(env.N))

tf = try_import_tf()
# tf.compat.v1.disable_eager_execution()


class KP0ActionMaskModel(TFModelV2):

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name, true_obs_shape=(11,),
                 action_embed_size=5, *args, **kwargs):

        super(KP0ActionMaskModel, self).__init__(obs_space,
                                                 action_space, num_outputs, model_config, name,
                                                 *args, **kwargs)

        # Embedding network that sees only the true observation (the
        # "state" part of the Dict obs), not the mask itself.
        self.action_embed_model = FullyConnectedNetwork(
            spaces.Box(0, 1, shape=true_obs_shape),
            action_space, action_embed_size,
            model_config, name + "_action_embedding")
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        # The Dict observation arrives as separate tensors.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]
        # Compute action logits from the true observation only.
        action_embedding, _ = self.action_embed_model({
            "obs": input_dict["obs"]["state"]})
        intent_vector = tf.expand_dims(action_embedding, 1)
        action_logits = tf.math.reduce_sum(avail_actions * intent_vector,
                                           axis=1)
        # tf.math.log(0) = -inf for masked-out actions; clamp to
        # tf.float32.min so adding the mask drives invalid logits to -inf.
        inf_mask = tf.math.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()


ModelCatalog.register_custom_model('kp_mask', KP0ActionMaskModel)


def register_env(env_name, env_config={}):
    env = create_env(env_name)
    # Register the env creator with Tune so the Trainer can build it by name.
    tune.register_env(env_name, lambda env_name: env(
        env_name, env_config=env_config))


register_env('Knapsack-v0', env_config=env_config)


ray.init(ignore_reinit_error=True)
trainer_config = {
    "model": {
        "custom_model": "kp_mask"
    },
    "env_config": env_config
}
trainer = agents.ppo.PPOTrainer(env='Knapsack-v0', config=trainer_config)

# Build an env instance, grab its Dict observation, and mask out action 0.
env = trainer.env_creator('Knapsack-v0')
state = env.state
state['action_mask'][0] = 0

# This call fails in Ray 1.0.1 because `state` is a dict observation.
actions = np.array([trainer.compute_action(state) for i in range(10)])

print(actions)

This script works fine in Ray 0.8.7, but in Ray 1.0.1 it raises an error, because trainer.compute_action() cannot handle dict-type observation inputs.
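
One way to see what goes wrong (a diagnostic sketch; default_policy is RLlib's default policy ID, and the local worker's preprocessors dict appears in the traceback below) is to inspect the preprocessor the local worker built:

# Diagnostic sketch: which preprocessor did the local worker build?
# In Ray 0.8.x it flattens the Dict observation before the shape check;
# in Ray 1.0.1 it is built from the already-flattened space, so the dict
# observation is never flattened and the ValueError below is raised.
prep = trainer.workers.local_worker().preprocessors["default_policy"]
print(type(prep).__name__)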

Error

Traceback (most recent call last):
  File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 60, in check_shape
    if not self._obs_space.contains(observation):
  File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/gym/spaces/box.py", line 128, in contains
    return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
AttributeError: 'dict' object has no attribute 'shape'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/notebooks/projects/hanyu/ReferProject/MahjongFastPK/test.py", line 96, in <module>
    actions = np.array([trainer.compute_action(state) for i in range(10)])
  File "/notebooks/projects/hanyu/ReferProject/MahjongFastPK/test.py", line 96, in <listcomp>
    actions = np.array([trainer.compute_action(state) for i in range(10)])
  File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 819, in compute_action
    preprocessed = self.workers.local_worker().preprocessors[
  File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 166, in transform
    self.check_shape(observation)
  File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 66, in check_shape
    raise ValueError(
ValueError: ('Observation for a Box/MultiBinary/MultiDiscrete space should be an np.array, not a Python list.', {'action_mask': array([0, 1, 1, 1, 1]), 'avail_actions': array([1., 1., 1., 1., 1.]), 'state': array([ 1, 12,  2,  1,  4,  2,  4,  2,  1, 10,  0])})

Question

The same code works fine in Ray 0.8.6, but in Ray 1.0.1 it raises the ValueError above.

So, what should I do to make compute_action() work with Dict-type observation inputs in Ray 1.0.1?

Thanks for any help!

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 1
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

hybug commented, Dec 9, 2020 (2 reactions)

Wonderful solution! Thanks for your selfless contribution!

sven1977 commented, Dec 9, 2020 (2 reactions)

Ok, here is the fix PR: https://github.com/ray-project/ray/pull/12706. As a workaround, could you change the following in your rllib/evaluation/worker_set.py (look for the if self._remote_workers: if-block)?

            # If num_workers > 0, get the action_spaces and observation_spaces
            # to not be forced to create an Env on the driver.
            if self._remote_workers:
                remote_spaces = ray.get(self.remote_workers(
                )[0].foreach_policy.remote(
                    lambda p, pid: (pid, p.observation_space, p.action_space)))
                spaces = {
                    e[0]: (getattr(e[1], "original_space", e[1]), e[2])
                    for e in remote_spaces
                }
            ...

The problem was that the local worker (driver) was using the already-preprocessed space (which it got from the remote worker) to build its own policy/preprocessor stack. That’s why the necessary DictFlatteningPreprocessor was never built and your Trainer did not do any preprocessing of your input dict observation before sending the data to the Policy.
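
For reference, here is a minimal sketch of the flattening step that should have happened, using RLlib's get_preprocessor utility (assuming env.observation_space is the original gym.spaces.Dict with the mask enabled):

# Sketch: what the DictFlatteningPreprocessor does to the Dict observation.
# get_preprocessor() returns DictFlatteningPreprocessor for a Dict space.
from ray.rllib.models.preprocessors import get_preprocessor

prep = get_preprocessor(env.observation_space)(env.observation_space)
flat = prep.transform(state)
print(flat.shape)  # one flat np.array: action_mask, avail_actions and state concatenated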
