[rllib] Recurrent models examples don't work if eager mode is on
What is the problem?
Ray version and other system information (Python version, TensorFlow version, OS): Ray 0.8.0
Reproduction (REQUIRED)
Please provide a script that can be run to reproduce the issue. The script should have no external library dependencies (i.e., use fake or mock data / environments):
If we cannot run your script, we cannot fix your issue.
- I have verified my script runs in a clean environment and reproduces the issue.
- I have verified the issue also occurs with the latest wheels.
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import random

import gym
from gym.spaces import Discrete
import numpy as np

import ray
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default="PPO")
parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
parser.add_argument("--stop", type=int, default=90)


class MyKerasRNN(RecurrentTFModelV2):
    """Example of using the Keras functional API to define a RNN model."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 hiddens_size=256,
                 cell_size=64):
        super(MyKerasRNN, self).__init__(obs_space, action_space, num_outputs,
                                         model_config, name)
        self.cell_size = cell_size

        # Define input layers
        input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell
        dense1 = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1")(input_layer)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm")(
                inputs=dense1,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        # Postprocess LSTM output with another hidden layer and compute values
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(lstm_out)
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(lstm_out)

        # Create the RNN model
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.register_variables(self.rnn_model.variables)
        self.rnn_model.summary()

    @override(RecurrentTFModelV2)
    def forward_rnn(self, inputs, state, seq_lens):
        model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
                                                          state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class RepeatInitialEnv(gym.Env):
    """Simple env in which the policy learns to repeat the initial observation
    seen at timestep 0."""

    def __init__(self):
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)
        self.token = None
        self.num_steps = 0

    def reset(self):
        self.token = random.choice([0, 1])
        self.num_steps = 0
        return self.token

    def step(self, action):
        if action == self.token:
            reward = 1
        else:
            reward = -1
        self.num_steps += 1
        done = self.num_steps > 100
        return 0, reward, done, {}


class RepeatAfterMeEnv(gym.Env):
    """Simple env in which the policy learns to repeat a previous observation
    token after a given delay."""

    def __init__(self, config):
        self.observation_space = Discrete(2)
        self.action_space = Discrete(2)
        self.delay = config["repeat_delay"]
        assert self.delay >= 1, "delay must be at least 1"
        self.history = []

    def reset(self):
        self.history = [0] * self.delay
        return self._next_obs()

    def step(self, action):
        if action == self.history[-(1 + self.delay)]:
            reward = 1
        else:
            reward = -1
        done = len(self.history) > 100
        return self._next_obs(), reward, done, {}

    def _next_obs(self):
        token = random.choice([0, 1])
        self.history.append(token)
        return token


if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    ModelCatalog.register_custom_model("rnn", MyKerasRNN)
    register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
    register_env("RepeatInitialEnv", lambda _: RepeatInitialEnv())
    tune.run(
        args.run,
        stop={"episode_reward_mean": args.stop},
        config={
            "env": args.env,
            "eager": True,
            "env_config": {
                "repeat_delay": 2,
            },
            "gamma": 0.9,
            "num_workers": 0,
            "num_envs_per_worker": 20,
            "entropy_coeff": 0.001,
            "num_sgd_iter": 5,
            "vf_loss_coeff": 1e-5,
            "model": {
                "custom_model": "rnn",
                "max_seq_len": 20,
            },
        })
```
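The failure only shows up with `"eager": True`; the title implies the same model trains fine in graph mode. As a quick check (a sketch, not part of the original report; it reuses the `rnn` model and envs registered by the script above), you can run the same config with only the eager flag flipped:

```python
# Sketch: same registrations and config as above, only the eager flag differs.
# With "eager": False the run is expected to proceed in graph mode; with
# "eager": True it hits the recurrent-model failure this issue reports.
for eager in (False, True):
    tune.run(
        "PPO",
        stop={"training_iteration": 1},
        config={
            "env": "RepeatAfterMeEnv",
            "eager": eager,
            "env_config": {"repeat_delay": 2},
            "num_workers": 0,
            "model": {"custom_model": "rnn", "max_seq_len": 20},
        })
```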
Top GitHub Comments
Yeah, that’s about as far as I got too. Didn’t realize it was just casting errors; I can probably open up a PR to fix this if that’s the remaining error.
Yup, works as well (also w/o eager). I’ll close. See this patch: https://github.com/ray-project/ray/pull/7021
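For anyone hitting this before picking up the linked patch, here is a minimal workaround sketch, assuming the casting error mentioned above is a `seq_lens` dtype mismatch against the model's `tf.int32` `seq_in` input (the authoritative fix is the linked PR, not this snippet). It would replace `MyKerasRNN.forward_rnn` in the script above:

```python
# Hypothetical workaround, not the actual patch: cast seq_lens to the dtype
# the Keras "seq_in" input declares (tf.int32) before calling the model,
# so the eager-mode call doesn't trip over a float/int mismatch.
@override(RecurrentTFModelV2)
def forward_rnn(self, inputs, state, seq_lens):
    seq_lens = tf.cast(seq_lens, tf.int32)  # assumed source of the cast error
    model_out, self._value_out, h, c = self.rnn_model(
        [inputs, seq_lens] + state)
    return model_out, [h, c]
```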