Allow creating module instances outside hk.transform

This is as much a question as it is a feature request. What is the reasoning for not allowing a module instance to be created (but not used) outside hk.transform? I took a look at hk.Module and ModuleMetaclass, but I feared my soul would get harvested by the dark forbidden magic involved before I could identify all the API features it permits.

For example, I would have expected this to be possible:

linear = hk.Linear(10)  # currently not allowed

def forward(x):
  return linear(x)

model = hk.transform(forward)

Concretely, I’m curious to know what would have to be sacrificed (if anything) to support this kind of usage? Is it meant to prevent a module instance from being used in two different functions wrapped by two different hk.transform calls?

I wouldn’t be surprised if I were missing some nasty side effect if you were to allow module creation outside of hk.transform, but, if not, I think it would be more intuitive to allow this kind of usage.

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
awav commented, Mar 23, 2020

Hello @trevorcai, @tomhennigan. I like out-of-the-box solutions a lot, but I am struggling to extend Haiku at the moment. I need constrained parameters, such as a variance (positive only) for Gaussian distributions. Such a parameter can be represented as a composition constraint: unconstrained_parameter -> bijector.forward(parameter); in my code it is a property of the module. The parameter dictionary contains only the unconstrained values, but for tracking and model printing we need the constrained values, and there is no way to get them because the model instance is hidden inside the function.

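from typing import Optional

import haiku as hk
import jax.numpy as jnp

class Parameter(hk.Module):  # must subclass hk.Module for super().__init__(name=...) to work
  def __init__(self, init_value: float, name: str):
    super().__init__(name="parameter")
    self._param_name = name
    # Store log(init_value) so that exp() of the stored value recovers init_value.
    self._init = hk.initializers.Constant(jnp.log(init_value))

  def __call__(self):
    # Only the unconstrained value is stored in the parameter dictionary.
    unconstrained = hk.get_parameter(
        f"unconstrained_{self._param_name}", shape=[], init=self._init)
    return jnp.exp(unconstrained)  # exp keeps the returned value positive

class Model(hk.Module):
  def __init__(self, init_variance: float, name: Optional[str] = None):
    super().__init__(name=name)
    self._variance = Parameter(init_variance, "variance")

  @property
  def variance(self):
    return self._variance()

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(self.variance * x)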

As you can see, a variance value in a parameter dictionary will not have much meaning without information about a transformation that a model uses (could be exp, softplus or another positive bijector).

1. One solution could be to return a model with transformed functions.

def forward_fn(x):
  m = Model(0.1)
  hk.link(m)  
  return m(x)

forward = hk.transform(forward_fn)
model = forward.linked_objects  # get access to read only object

2. Another possible (?) solution could be making hk.transform a context manager

class Holder(hk.ModuleHolder):
  @hk.transform
  def forward(self, x):
    self.model = Model(0.1)
    return self.model(x)

forward = Holder().forward()

PS: for me, it is a very important issue and a deciding factor on how I’m going to use the library.

0 reactions
trevorcai commented, Mar 3, 2020

That’s good to hear - that’s been my experience as well 😃 I’ll leave this issue open to track this FR.
