Correct way to transform and init a `hk.Module` with non-default parameter?

Hey all!

I’m trying to run a linear regression example and I’ve got the following:

import jax.numpy as jnp
from sklearn.datasets import load_boston
import haiku as hk
import optax
import jax


X, y = load_boston(return_X_y=True)
train_X = jnp.asarray(X.tolist())
train_y = jnp.asarray(y.tolist())
    
class Model(hk.Module):
    def __init__(self, input_dims):
        super().__init__()
        self.input_dims = input_dims
    
    def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
        l1 = hk.Linear(self.input_dims)
        return l1(X)
    
model = hk.transform(lambda x: Model()(x))  # <-- where I would specify the model shape if at all? 

So I’m running into an issue where I’m not able to specify the model shape. If I don’t specify it, as above, I get the error:

__init__() missing 1 required positional argument: 'input_dims'

but if I do specify the shape via

model = hk.transform(lambda x: Model(train_X.shape[1])(x))

I get the error:

Argument '<function without_state.<locals>.init_fn at 0x7f1e5c616430>' of type <class 'function'> is not a valid JAX type.

What is the recommended way of addressing this? I’ve been reading through the hk.transform docs, but I’m still not sure. The code examples include __init__ methods whose arguments have no defaults, so I know it’s possible.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6

Top GitHub Comments

1 reaction
tomhennigan commented, Sep 19, 2021

> Is there a function signature somewhere for f.init?

Given a function f(*a, **k) -> out, hk.transform(f) gives you back a pair of functions: f.init(rng, *a, **k) -> params and f.apply(params, rng, *a, **k) -> out.

If you want a deeper understanding of what happens inside transform, take a look here: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html

By the way, there is another common option, which is to transform a method on a regular object:

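import haiku as hk
import jax
import jax.numpy as jnp

class Model:
  def __init__(self, input_dims):
    self.input_dims = input_dims
    # Transform the bound method; init/apply become plain functions on the instance.
    self.init, self.apply = hk.transform(self._forward)

  def _forward(self, x):
    m = hk.Linear(self.input_dims)
    return m(x)

rng = jax.random.PRNGKey(0)
x = jnp.ones([8, 13])  # dummy input batch
m = Model(input_dims=10)
params = m.init(rng, x)
out = m.apply(params, rng, x)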

> I’ve attached my colab link here. Thank you for your help!

So the key issue here is that jax.grad requires the arguments it differentiates with respect to (by default, the first positional argument) to be JAX arrays, or pytrees of them. You can reproduce your error with the following minimal code:

>>> jax.grad(lambda x: x)(lambda: None)
...
TypeError: Argument '<function <lambda> at 0x7f2116434290>' of type <class 'function'> is not a valid JAX type.
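
A minimal sketch of the failure and the usual workaround, assuming apply_fn is a hypothetical stand-in for a model's apply function: passing the function itself as the differentiated argument raises the TypeError, while closing over it and passing only arrays works.

```python
import jax
import jax.numpy as jnp


def apply_fn(w, x):
    # Hypothetical stand-in for a model's apply function.
    return w * x


# Fails: the function is the argument grad would differentiate with respect to.
try:
    jax.grad(lambda fn: 0.0)(apply_fn)
    raised = False
except TypeError:
    raised = True

# Works: close over the function so grad only receives arrays.
xs = jnp.arange(3.0)


def loss(w):
    return jnp.sum(apply_fn(w, xs) ** 2)


g = jax.grad(loss)(2.0)  # d/dw sum((w * xs)**2) = 2 * w * sum(xs**2) = 20.0
```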

I’ve modified your colab to use the pattern described above (a regular Python object holding the Haiku- and JAX-transformed methods), and the loss seems to go in the right direction: https://colab.research.google.com/gist/tomhennigan/456984830510eded8f1675476bf1ff8f/haiku_ianqs_dbg.ipynb

0 reactions
IanQS commented, Sep 20, 2021

> I don’t think this is true; you can drive modules in arbitrarily complex ways inside your transformed function. There is no requirement in Haiku for there to be a single top-level module inside the transform, or for modules to be called in a particular order. For example, you might want to add a residual inside your sequential stack:

Ahh, sorry, I was conflating hk.Sequential with defining the computation graph in a function. Yeah, that makes sense: defining the graph in a function is probably just another way of expressing it compared to a module. It probably all thunks down to the same thing.

> ...decided to use it, you can be confident that it should work well (if your needs are similar to ours).

Thanks! I’m looking at it for research, and also with a view to spinning up a startup on it. I’m mostly evaluating which library makes the most sense: which has the clearest documentation, is better tested, has fewer (or more explicitly documented) sharp edges, and so forth.
