[rllib] Weight sharing example does not seem to actually share weights
What is the problem?
Versions: Python 3.6 | TF 2.1 | RLlib 1.0.0
Weight sharing in TensorFlow, as implemented in rllib/examples/models/shared_weights_model.py, creates two separate sets of weights when it is supposed to reuse part of them.
Reproduction (REQUIRED)
import gym

from ray.rllib.examples.models.shared_weights_model import SharedWeightsModel1, SharedWeightsModel2

observation_space = gym.spaces.Box(low=-1, high=1, shape=(10,))
action_space = gym.spaces.Discrete(3)

m1 = SharedWeightsModel1(
    observation_space=observation_space,
    action_space=action_space,
    num_outputs=3,
    model_config={},
    name="m1",
)
m2 = SharedWeightsModel2(
    observation_space=observation_space,
    action_space=action_space,
    num_outputs=3,
    model_config={},
    name="m2",
)

v1 = m1.trainable_variables()
v2 = m2.trainable_variables()
for i in range(len(v1)):
    print(v1[i].name, v2[i].name)
Output
fc1/kernel:0 fc1_1/kernel:0
fc1/bias:0 fc1_1/bias:0
fc_out/kernel:0 fc_out_1/kernel:0
fc_out/bias:0 fc_out_1/bias:0
value_out/kernel:0 value_out_1/kernel:0
value_out/bias:0 value_out_1/bias:0
The fc1 weights should be the same, but apparently they are not. It seems that tf.keras layers are not affected by the variable scope: Keras creates its variables directly via tf.Variable rather than through tf1.get_variable, so reuse=tf1.AUTO_REUSE never applies to them. The solution could be to create the layer once, outside both classes, and use that same layer object in each.
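Here is a minimal standalone sketch of the root cause (my own illustration, assuming TF 2.x; not taken from the RLlib example):

import tensorflow as tf

tf1 = tf.compat.v1
x = tf.zeros((1, 8))

# Two layer *objects*, even with identical names inside an AUTO_REUSE
# scope, each create their own variables: Keras allocates tf.Variable
# directly, so the v1 reuse machinery never sees them.
with tf1.variable_scope("shared", reuse=tf1.AUTO_REUSE):
    a = tf.keras.layers.Dense(4, name="fc1")
    _ = a(x)
with tf1.variable_scope("shared", reuse=tf1.AUTO_REUSE):
    b = tf.keras.layers.Dense(4, name="fc1")
    _ = b(x)
print(a.kernel is b.kernel)  # False -> not shared

# Calling one layer *object* twice is what actually shares the weights.
shared = tf.keras.layers.Dense(4, name="fc1")
_ = shared(x)
_ = shared(x)
print(len(shared.trainable_variables))  # 2 -> a single kernel/bias pair

With that in mind, this script seems to work: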
import gym

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

# Create the shared layer once, at module level, so both models below
# call the exact same layer object (and hence the same variables).
shared_last_layer = tf.keras.layers.Dense(
    units=64, activation=tf.nn.relu, name="fc1")
class SharedWeightsModel1(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    Here, we share the variables of the 'fc1' layer by creating the layer
    object once at module level and calling that same object from both
    models (instead of relying on a tf1.AUTO_REUSE variable scope, which
    tf.keras layers ignore).
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)
        inputs = tf.keras.layers.Input(observation_space.shape)
        last_layer = shared_last_layer(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])
        self.register_variables(self.base_model.variables)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
class SharedWeightsModel2(TFModelV2):
    """The "other" TFModelV2, reusing the same shared layer as the one above."""

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)
        inputs = tf.keras.layers.Input(observation_space.shape)
        last_layer = shared_last_layer(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])
        self.register_variables(self.base_model.variables)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
observation_space = gym.spaces.Box(low=-1, high=1, shape=(10,))
action_space = gym.spaces.Discrete(3)

m1 = SharedWeightsModel1(
    observation_space=observation_space,
    action_space=action_space,
    num_outputs=3,
    model_config={},
    name="m1",
)
m2 = SharedWeightsModel2(
    observation_space=observation_space,
    action_space=action_space,
    num_outputs=3,
    model_config={},
    name="m2",
)

v1 = m1.trainable_variables()
v2 = m2.trainable_variables()
for i in range(len(v1)):
    print(v1[i].name, v2[i].name)
whose output is:
fc1/kernel:0 fc1/kernel:0
fc1/bias:0 fc1/bias:0
fc_out/kernel:0 fc_out_1/kernel:0
fc_out/bias:0 fc_out_1/bias:0
value_out/kernel:0 value_out_1/kernel:0
value_out/bias:0 value_out_1/bias:0
Here we can see that fc1 is indeed shared.
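As an extra sanity check (my own addition, assuming the script above has just been run), one can verify that the two models hold the very same variable objects, not just identically named ones:

# fc1's kernel and bias should be the *same* Python objects in both models.
ids1 = {id(v) for v in m1.trainable_variables()}
ids2 = {id(v) for v in m2.trainable_variables()}
print(len(ids1 & ids2))  # expected: 2 (fc1/kernel and fc1/bias)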
- [x] I have verified my script runs in a clean environment and reproduces the issue.
- [x] I have verified the issue also occurs with the latest wheels.
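For completeness, here is a hedged sketch (my addition, not part of the original report) of how these two models could be wired into a multi-agent setup via RLlib's model catalog; the policy ids, registered model names, and mapping function are purely illustrative:

from ray.rllib.models import ModelCatalog

# Register the two custom models (the string names are illustrative).
ModelCatalog.register_custom_model("shared_model_1", SharedWeightsModel1)
ModelCatalog.register_custom_model("shared_model_2", SharedWeightsModel2)

config = {
    "multiagent": {
        "policies": {
            # (policy_cls, obs_space, act_space, config); None = default class.
            "pol1": (None, observation_space, action_space,
                     {"model": {"custom_model": "shared_model_1"}}),
            "pol2": (None, observation_space, action_space,
                     {"model": {"custom_model": "shared_model_2"}}),
        },
        "policy_mapping_fn": lambda agent_id: "pol1" if agent_id == 0 else "pol2",
    },
}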

Great! Thanks for taking the time to look at this!
Np, 😃