[question] Using keras in Custom Policy

I am trying to use Keras to define my own custom policy. Unfortunately, after several hours of trying, I couldn’t get it to train on CartPole.

Here is the CustomPolicy example, modified to work with CartPole. This version trains properly.

import tensorflow as tf
from stable_baselines.common.policies import ActorCriticPolicy


class CustomPolicy(ActorCriticPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
        super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=False)

        with tf.variable_scope("model", reuse=reuse):
            activ = tf.nn.tanh

            extracted_features = tf.layers.flatten(self.processed_obs)

            pi_h = extracted_features
            for i, layer_size in enumerate([64, 64]):
                pi_h = activ(tf.layers.dense(pi_h, layer_size, name='pi_fc' + str(i)))
            pi_latent = pi_h

            vf_h = extracted_features
            for i, layer_size in enumerate([64, 64]):
                vf_h = activ(tf.layers.dense(vf_h, layer_size, name='vf_fc' + str(i)))
            value_fn = tf.layers.dense(vf_h, 1, name='vf')
            vf_latent = vf_h

            self.proba_distribution, self.policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)

        self.value_fn = value_fn
        self.initial_state = None
        self._setup_init()

Here is the Keras version of my implementation, which runs but does NOT train. Switching between tf.keras.layers and keras.layers makes no difference.

class KerasPolicy(ActorCriticPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
        super(KerasPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=False)

        with tf.variable_scope("model", reuse=reuse):
            flat = tf.keras.layers.Flatten()(self.processed_obs)

            x = tf.keras.layers.Dense(64, activation="tanh", name='pi_fc_0')(flat)
            pi_latent = tf.keras.layers.Dense(64, activation="tanh", name='pi_fc_1')(x)

            x1 = tf.keras.layers.Dense(64, activation="tanh", name='vf_fc_0')(flat)
            vf_latent = tf.keras.layers.Dense(64, activation="tanh", name='vf_fc_1')(x1)

            value_fn = tf.keras.layers.Dense(1, name='vf')(vf_latent)

            self.proba_distribution, self.policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)

        self.value_fn = value_fn
        self.initial_state = None
        self._setup_init()

I tried to keep both implementations as close to each other as possible. Any help at this point would be greatly appreciated.

Thank you in advance

Keras version: 2.2.2
TensorFlow version: 1.12.0
Stable Baselines version: 2.4.0a

Attached is the minimal code to reproduce the issue, along with TensorBoard graphs for comparison: custom_model.py.zip
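
One possible explanation, which is not confirmed anywhere in this thread: tf.layers.dense creates its variables through tf.get_variable and therefore respects tf.variable_scope(..., reuse=reuse), whereas tf.keras.layers objects create fresh variables each time they are instantiated. If the algorithm builds the policy more than once with reuse=True (for example one copy for acting and one for training), the Keras version could end up with two independent sets of weights, so the copy used for acting would never be updated. A rough way to check this is to snapshot the weights before and after a short training run and see which variables actually changed. The sketch below is plain Python; the helper name changed_variables, the variable names, and the numbers are all hypothetical, and how the snapshots are obtained (for instance by evaluating the trainable variables in the model's session) is left open.

```python
def changed_variables(before, after, tol=1e-8):
    """Return the names of variables whose values differ between two snapshots.

    `before` and `after` map a variable name to a flat sequence of floats.
    Variables missing from `after` are skipped.
    """
    changed = []
    for name in sorted(before):
        if name not in after:
            continue
        old, new = before[name], after[name]
        if len(old) != len(new) or any(abs(a - b) > tol for a, b in zip(old, new)):
            changed.append(name)
    return changed


# Hypothetical snapshots: names and values are made up for illustration only.
before = {
    "model/pi_fc_0/kernel": [0.10, -0.20, 0.30],
    "model_1/pi_fc_0/kernel": [0.05, 0.15, -0.25],
}
after = {
    "model/pi_fc_0/kernel": [0.10, -0.20, 0.30],    # identical: never updated
    "model_1/pi_fc_0/kernel": [0.07, 0.11, -0.21],  # different: updated in training
}

print(changed_variables(before, after))
```

If only one of two same-shaped copies of a layer changes, as in the made-up data above, that would be consistent with the weights not being shared between policy instances. If every variable changes, the cause is more likely elsewhere.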

Issue Analytics

  • State: open
  • Created 5 years ago
  • Comments: 10 (2 by maintainers)

Top GitHub Comments

michalgregor commented, May 14, 2019 (3 reactions)

Are there any further plans regarding this? Now that we know TF 2.0 is going to drop tf.variable_scope and even handle sessions differently, will everything pretty much have to be rewritten?

jckastel commented, Dec 6, 2019 (2 reactions)

Would like to add my vote here as well. Will this get fixed at some point, or will we have to wait for the TF 2.0-compatible version? Not being able to use predefined Keras layers means that a ton of really useful model and layer libraries are unusable with stable-baselines, and that model code will be less future-proof and much more difficult to read and maintain. This is a very unfortunate limitation to an otherwise really nice Deep RL library.
