Correct way to transform and init a `hk.Module` with non-default parameter?
Hey all!
I’m trying to run a linear regression example and I’ve got the following:
import jax.numpy as jnp
from sklearn.datasets import load_boston
import haiku as hk
import optax
import jax

X, y = load_boston(return_X_y=True)
train_X = jnp.asarray(X.tolist())
train_y = jnp.asarray(y.tolist())

class Model(hk.Module):
    def __init__(self, input_dims):
        super().__init__()
        self.input_dims = input_dims

    def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
        l1 = hk.Linear(self.input_dims)
        return l1(X)

model = hk.transform(lambda x: Model()(x))  # <-- where I would specify the model shape if at all?
I’m running into an issue where I’m not able to specify the model shape. If I do not specify it, as in the above, I get the error
__init__() missing 1 required positional argument: 'input_dims'
but if I do specify the shape via
model = hk.transform(lambda x: Model(train_X.shape[1])(x))
I get
Argument '<function without_state.<locals>.init_fn at 0x7f1e5c616430>' of type <class 'function'> is not a valid JAX type.
What is the recommended way of addressing this? I’m reading through hk.transform but I’m not sure. Looking at the code examples, there are __init__ functions without default args, so I know it’s possible.
Issue Analytics
- State:
- Created 2 years ago
- Comments:6
Given a function f(*a, **k) -> out, hk.transform(f) gives you back a pair of functions: f.init(rng, *a, **k) -> params and f.apply(params, rng, *a, **k) -> out.
If you wanted to get a deeper understanding of what happens inside transform then take a look here: https://dm-haiku.readthedocs.io/en/latest/notebooks/build_your_own_haiku.html
By the way, there is another common option, which is to transform a method on a regular object:
So the key issue here is that jax.grad requires all arguments to be JAX Arrays. You can reproduce your error with the following minimal code:
I’ve modified your colab to use the pattern described above (have a regular python object holding the haiku and jax transformed methods) and the loss seems to go in the right direction: https://colab.research.google.com/gist/tomhennigan/456984830510eded8f1675476bf1ff8f/haiku_ianqs_dbg.ipynb
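The minimal reproduction mentioned in this comment is not preserved here. One way an error of this kind can arise, sketched with hypothetical `init_fn` and `loss` functions, is passing a Python function in the position that `jax.grad` differentiates. Whether the original code did exactly this is an assumption.

```python
import jax
import jax.numpy as jnp


def init_fn(rng):
    # Stand-in for a function such as the `init` returned by hk.transform.
    return jnp.zeros(())


def loss(fn, x):
    # `fn` sits in the first position, which jax.grad differentiates by default.
    return jnp.sum(x ** 2)


try:
    jax.grad(loss)(init_fn, jnp.ones(3))
    error_message = None
except TypeError as err:
    # JAX rejects the function because it is not an array or a container of arrays.
    error_message = str(err)

# Differentiating with respect to an array works as expected.
grads = jax.grad(lambda x: jnp.sum(x ** 2))(jnp.ones(3))
```

Closing over the function, or keeping it as an attribute on an object, keeps it out of the differentiated arguments.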
Ahh, sorry, I was conflating hk.Sequential with defining the computation graph in a function. Yeah, that makes sense, as defining the graph in a function is probably just another way of expressing it compared to a module. It probably all boils down to the same thing.
Thanks! I’m looking at it for research, and also in terms of spinning up a startup on it. I’m mostly evaluating which library makes the most sense: which has the clearest documentation, is better tested, has fewer or more explicitly documented sharp edges, and so forth.