[rllib] Trainer.compute_action Error with Dict type observation inputs
What is the problem?
Python: 3.8.5, TensorFlow: tensorflow-gpu 2.0.0, Ray: 1.0.1 & 0.8.6
I want to reproduce the code from the blog post Action Masking with RLlib, but I get an error.
Training Script
Here is the script (install or_gym with pip first):
from or_gym.utils import create_env
from gym import spaces
from ray.rllib.utils import try_import_tf
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models import ModelCatalog
from ray import tune
from ray.rllib import agents
import ray
import or_gym
import numpy as np

env = or_gym.make('Knapsack-v0')
print("Max weight capacity:\t{}kg".format(env.max_weight))
print("Number of items:\t{}".format(env.N))

env_config = {'N': 5,
              'max_weight': 15,
              'item_weights': np.array([1, 12, 2, 1, 4]),
              'item_values': np.array([2, 4, 2, 1, 10]),
              'mask': True}

env = or_gym.make('Knapsack-v0', env_config=env_config)
print("Max weight capacity:\t{}kg".format(env.max_weight))
print("Number of items:\t{}".format(env.N))

tf = try_import_tf()
# tf.compat.v1.disable_eager_execution()


class KP0ActionMaskModel(TFModelV2):

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name, true_obs_shape=(11,),
                 action_embed_size=5, *args, **kwargs):
        super(KP0ActionMaskModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name,
            *args, **kwargs)
        # Embedding network that only sees the true observation ("state"),
        # not the action mask or the available-action vectors.
        self.action_embed_model = FullyConnectedNetwork(
            spaces.Box(0, 1, shape=true_obs_shape),
            action_space, action_embed_size,
            model_config, name + "_action_embedding")
        self.register_variables(self.action_embed_model.variables())

    def forward(self, input_dict, state, seq_lens):
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]
        action_embedding, _ = self.action_embed_model({
            "obs": input_dict["obs"]["state"]})
        intent_vector = tf.expand_dims(action_embedding, 1)
        action_logits = tf.math.reduce_sum(avail_actions * intent_vector,
                                           axis=1)
        # Mask out invalid actions by adding a very large negative value
        # (log(0) clipped to tf.float32.min) to their logits.
        inf_mask = tf.math.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()


ModelCatalog.register_custom_model('kp_mask', KP0ActionMaskModel)


def register_env(env_name, env_config={}):
    env = create_env(env_name)
    tune.register_env(env_name,
                      lambda env_name: env(env_name, env_config=env_config))


register_env('Knapsack-v0', env_config=env_config)

ray.init(ignore_reinit_error=True)

trainer_config = {
    "model": {
        "custom_model": "kp_mask"
    },
    "env_config": env_config
}
trainer = agents.ppo.PPOTrainer(env='Knapsack-v0', config=trainer_config)

env = trainer.env_creator('Knapsack-v0')
state = env.state
state['action_mask'][0] = 0
actions = np.array([trainer.compute_action(state) for i in range(10)])
print(actions)
This script works fine in Ray 0.8.7, but in Ray 1.0.1 it raises an error, because trainer.compute_action() cannot handle a Dict-type observation.
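For reference, the observation the masked Knapsack-v0 env returns is a plain Python dict of NumPy arrays. A sketch of its structure, with the values copied from the ValueError in the traceback below (the state variable is the one from the script above):

state = {
    'action_mask': np.array([0, 1, 1, 1, 1]),        # item 0 has been masked out
    'avail_actions': np.array([1., 1., 1., 1., 1.]),
    'state': np.array([1, 12, 2, 1, 4, 2, 4, 2, 1, 10, 0]),
}

The Box space's contains() check calls .shape on this dict, which raises the AttributeError and then the ValueError shown below.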
Error
Traceback (most recent call last):
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 60, in check_shape
if not self._obs_space.contains(observation):
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/gym/spaces/box.py", line 128, in contains
return x.shape == self.shape and np.all(x >= self.low) and np.all(x <= self.high)
AttributeError: 'dict' object has no attribute 'shape'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/notebooks/projects/hanyu/ReferProject/MahjongFastPK/test.py", line 96, in <module>
actions = np.array([trainer.compute_action(state) for i in range(10)])
File "/notebooks/projects/hanyu/ReferProject/MahjongFastPK/test.py", line 96, in <listcomp>
actions = np.array([trainer.compute_action(state) for i in range(10)])
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 819, in compute_action
preprocessed = self.workers.local_worker().preprocessors[
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 166, in transform
self.check_shape(observation)
File "/data2/huangcq/miniconda3/envs/majenv/lib/python3.8/site-packages/ray/rllib/models/preprocessors.py", line 66, in check_shape
raise ValueError(
ValueError: ('Observation for a Box/MultiBinary/MultiDiscrete space should be an np.array, not a Python list.', {'action_mask': array([0, 1, 1, 1, 1]), 'avail_actions': array([1., 1., 1., 1., 1.]), 'state': array([ 1, 12, 2, 1, 4, 2, 4, 2, 1, 10, 0])})
Question
The problem is that the code works fine in Ray 0.8.6, but in Ray 1.0.1 it raises the ValueError shown above.
So, what should I do to make compute_action() handle a Dict-type observation in Ray 1.0.1?
Thanks for any help!

Wonderful solution! Thanks to your selfless contribution!
Ok, here is the fix PR: https://github.com/ray-project/ray/pull/12706
As a workaround, could you change the following in your rllib/evaluation/worker_set.py (look for the if self._remote_workers: if-block)?
The problem was that the local worker (driver) was using the already-preprocessed space (which it got from the remote worker) to build its own policy/preprocessor stack. That's why the necessary DictFlatteningPreprocessor was never built, and your Trainer did not do any preprocessing of your input dict observation before sending the data to the Policy.
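Until the fix lands, a minimal, untested sketch of a user-level workaround (assuming the Ray 1.0.x preprocessor API, and reusing env, state and trainer from the script above; this is not the worker_set.py change referenced here) is to flatten the Dict observation yourself before calling compute_action:

from ray.rllib.models.preprocessors import get_preprocessor

# Build the preprocessor RLlib would normally create for the Dict observation space.
prep = get_preprocessor(env.observation_space)(env.observation_space)

# Flatten the dict observation into a single np.array and query the trainer with that.
flat_obs = prep.transform(state)
actions = np.array([trainer.compute_action(flat_obs) for _ in range(10)])

Since the local worker in 1.0.1 was built against the already-flattened Box space, handing it an already-flattened array should pass its shape check.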
if self._remote_workers:if-block)?The problem was that the local-worker (driver) was using the already preprocessed space (got it from the remote-worker) to build its own policy/preprocessot stack. That’s why the necessary DictFlatteningPreprocessor was never built and your Trainer did not do any preprocessing (on your input dict’s observation) prior to sending the data to the Policy.