How to do transfer learning with the Tune or RLlib API?
See original GitHub issue
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04
- Ray installed from (source or binary): Binary
- Ray version: 0.7.3
- Python version: 3.7
- Exact command to reproduce:
Describe the problem
I created an Atari Net with TFModelV2. Is it possible to reinitialize only specific layers (e.g. layer_out and value_out) after restoring from a checkpoint?
I tried to get the tf.Variable objects with trainer.get_policy().model.variables() and assign new values to them, but an error showed up: ValueError: Tensor("random_uniform:0", shape=(8, 8, 1, 32), dtype=float32) must be from the same graph as Tensor("default_policy/conv2d/kernel:0", shape=(), dtype=resource).
Since transfer learning is a common technique, I hope this issue helps other people who run into the same problem.
Source code / logs
# Imports assumed from context (paths as in RLlib's ModelV2 examples; they may
# differ slightly across Ray versions).
import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2

class KerasAtariNet(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(KerasAtariNet, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        # Standard Atari conv stack followed by separate policy and value heads.
        self.inputs = tf.keras.layers.Input(shape=obs_space.shape, name="observations")
        conv1 = tf.keras.layers.Conv2D(filters=32, kernel_size=8, strides=4, activation='relu')(self.inputs)
        conv2 = tf.keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, activation='relu')(conv1)
        conv3 = tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=1, activation='relu')(conv2)
        conv_flatten = tf.keras.layers.Flatten()(conv3)
        state = tf.keras.layers.Dense(512, activation='relu')(conv_flatten)
        layer_out = tf.keras.layers.Dense(num_outputs, name="act_output")(state)
        value_out = tf.keras.layers.Dense(1, name="value_output")(state)
        self.base_model = tf.keras.Model(self.inputs, [layer_out, value_out])
        self.register_variables(self.base_model.variables)

    def forward(self, input_dict, state, seq_lens):
        model_out, self._value_out = self.base_model(input_dict["obs"])
        return model_out, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
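For context (not shown in the original issue): a TFModelV2 subclass like this is normally registered with RLlib's ModelCatalog and selected through the "custom_model" key of the model config before building the trainer. A minimal sketch, assuming the class above; the "keras_atari_net" name and the config dict are illustrative only:

from ray.rllib.models import ModelCatalog

# Register the custom model under a name RLlib can look up.
ModelCatalog.register_custom_model("keras_atari_net", KerasAtariNet)

config = {
    "model": {
        "custom_model": "keras_atari_net",
    },
}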
import ray.rllib.agents.ppo as ppo  # assumed import for PPOTrainer

trainer = ppo.PPOTrainer(config=config, env="my_env")
# trainer.restore(chpt_path)
# print(trainer.get_policy().model.variables()[0].shape)
print(trainer.get_policy().model.variables()[0].eval(session=trainer.get_policy()._sess))
# The next line raises the cross-graph ValueError: the glorot_uniform tensor is
# created in the current default graph, not in the graph holding the policy's
# variables.
trainer.get_policy().model.variables()[0].assign(tf.initializers.glorot_uniform()(shape=trainer.get_policy().model.variables()[0].shape.as_list()))
# print(trainer.get_policy().model.variables()[0].numpy())
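One way to do the transfer without touching the TensorFlow graph at all (an illustration, not code from the issue) is to move weights around as numpy arrays with the policy's get_weights()/set_weights(). The sketch below assumes the config and chpt_path from the snippet above, and that the TF policy's get_weights() returns a dict keyed by variable name:

import ray.rllib.agents.ppo as ppo

# Trainer restored from the checkpoint of the source task.
restored = ppo.PPOTrainer(config=config, env="my_env")
restored.restore(chpt_path)

# Fresh trainer whose output heads keep their random initialization.
fresh = ppo.PPOTrainer(config=config, env="my_env")

old_weights = restored.get_policy().get_weights()
new_weights = fresh.get_policy().get_weights()
for var_name, value in old_weights.items():
    # Copy everything except the policy/value heads ("act_output",
    # "value_output"), which stay freshly initialized for the new task.
    if "act_output" not in var_name and "value_output" not in var_name:
        new_weights[var_name] = value
fresh.get_policy().set_weights(new_weights)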
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
I figured out how to solve the problem: the session from trainer.get_policy() must be used for evaluating the tensors. The problem was likely caused by using a different session than the one from trainer.get_policy(), so the values were re-initialized every time.
The problem has been solved. Closing this issue.
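As a concrete sketch of the fix described above (not code from the thread): select the head variables by name and run their existing initializer ops through the policy's own session. Because no new ops are created in a different graph, the "must be from the same graph" error does not occur:

policy = trainer.get_policy()
sess = policy._sess  # the session RLlib created for this policy's graph

# Pick only the output-head variables by name.
head_vars = [v for v in policy.model.variables()
             if "act_output" in v.name or "value_output" in v.name]

# Run the variables' existing initializer ops inside the policy's session,
# leaving the convolutional layers' restored weights untouched.
sess.run([v.initializer for v in head_vars])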