
Indicating training/testing modes in Sonnet callbacks


What is a convenient way of providing boolean training flags (e.g. is_training) that indicate, for example, whether to apply batch normalization, when using Sonnet callback functions?

Example:

def make_transpose_cnn_model():
    def transpose_convnet1d(inputs):
        inputs = tf.expand_dims(inputs, axis=2)

        outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(inputs)
        outputs = snt.BatchNorm()(outputs, is_training=True)  # <- want to have this as input
        outputs = tf.nn.relu(outputs)
        outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(outputs)
        outputs = snt.BatchNorm()(outputs, is_training=True)  # <- want to have this as input
        outputs = tf.nn.relu(outputs)
        outputs = snt.BatchFlatten()(outputs)
        # outputs = tf.nn.dropout(outputs, keep_prob=tf.constant(1.0))  # <- want to have this as input
        outputs = snt.Linear(output_size=128)(outputs)

        return outputs

    return transpose_convnet1d

and

self._network = modules.GraphIndependent(
    edge_model_fn=EncodeProcessDecode.make_mlp_model,
    node_model_fn=EncodeProcessDecode.make_transpose_cnn_model,
    global_model_fn=EncodeProcessDecode.make_mlp_model)

I can't pass this parameter through the _build() function as shown in the following, since the interface of modules.GraphIndependent won't allow it:

    def _build(self, input_op, num_processing_steps, is_training=True):
        latent = self._encoder(input_op, is_training)

It yields:

TypeError: _build() got an unexpected keyword argument 'is_training'
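The error occurs because the wrapper calls the wrapped model with a fixed signature and does not forward extra keyword arguments. A minimal, dependency-free sketch of that failure mode (the `GraphWrapper` and `NodeModel` classes here are hypothetical stand-ins, not the real graph_nets or Sonnet API):

```python
class NodeModel:
    """Stand-in for a node model that expects an extra training flag."""
    def __call__(self, inputs, is_training):
        return [x * (2 if is_training else 1) for x in inputs]

class GraphWrapper:
    """Hypothetical stand-in for a wrapper like modules.GraphIndependent:
    it calls the wrapped model with a fixed interface."""
    def __init__(self, node_model_fn):
        self._node_model = node_model_fn()

    def __call__(self, inputs):          # fixed signature: no **kwargs
        return self._node_model(inputs)  # is_training cannot be forwarded

wrapper = GraphWrapper(NodeModel)
try:
    wrapper([1, 2, 3], is_training=True)  # mirrors the failing call
except TypeError as e:
    print(e)  # unexpected keyword argument 'is_training', as in the issue
```

Because the extra kwarg cannot be threaded through the wrapper's call, the flag has to live somewhere else, which is what the accepted answer below addresses.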

Issue Analytics

  • State:closed
  • Created 5 years ago
  • Comments:8

Top GitHub Comments

1 reaction
malcolmreynolds commented, Nov 20, 2018

If you have a module that you want to compose out of other modules, and some of the submodules require extra arguments like is_training, the canonical thing to do would be to define a new subclass of AbstractModule, rather than writing a function as you have. In particular, defining a module allows you to explicitly reuse variables by just calling the module twice (potentially with different is_training kwarg values).

Your example would be something like this:

class TransposeCnnModel(snt.AbstractModule):
  def __init__(self, name='transpose_cnn_model'):
    super(TransposeCnnModel, self).__init__(name=name)

  def _build(self, inputs, is_training):
    inputs = tf.expand_dims(inputs, axis=2)

    outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(inputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.Conv1DTranspose(output_channels=2, kernel_shape=10, stride=1)(outputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    keep_prob = 0.7 if is_training else 1.0
    outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
    outputs = snt.Linear(output_size=128)(outputs)

    return outputs    

You would probably want to make more things configurable as constructor arguments (e.g. the final output size, the number of channels in the convolutions, the dropout probability for training), but the above should fit the API that the GraphNet library expects.
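The constructor-argument pattern suggested above can be shown without TensorFlow or Sonnet. A minimal plain-Python sketch (the class is illustrative; it stands in for a snt.AbstractModule subclass, where calling the instance dispatches to _build):

```python
class TransposeCnnModel:
    """Plain-Python sketch of the pattern: hyperparameters go in the
    constructor, the training flag is passed per call."""

    def __init__(self, output_size=128, train_keep_prob=0.7):
        self._output_size = output_size        # configurable final size
        self._train_keep_prob = train_keep_prob

    def _build(self, inputs, is_training):
        # The flag picks the dropout keep probability at call time, so one
        # module instance serves both training and evaluation.
        keep_prob = self._train_keep_prob if is_training else 1.0
        return {"size": self._output_size, "keep_prob": keep_prob}

    def __call__(self, *args, **kwargs):
        # In Sonnet, AbstractModule's __call__ dispatches to _build.
        return self._build(*args, **kwargs)

model = TransposeCnnModel(output_size=64, train_keep_prob=0.5)
print(model(None, is_training=True))   # {'size': 64, 'keep_prob': 0.5}
print(model(None, is_training=False))  # {'size': 64, 'keep_prob': 1.0}
```

Calling the same instance twice with different is_training values mirrors the variable-reuse behavior the answer describes: one set of parameters, two modes.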

As to your second point, I’m not sure exactly what you mean by Sonnet callback functions - could you be more specific?

0 reactions
ferreirafabio commented, Mar 19, 2019

It worked! Thank you.
