Support for custom pytrees
Hello, haiku team! Thanks a lot for making the awesome haiku.
I’m interested in sequential probabilistic models. Parameters of probabilistic models are usually constrained; a simple example is a variance, which can only be positive. I gave an example and an explanation of constrained parameters in https://github.com/deepmind/dm-haiku/issues/16#issuecomment-602087358. Pytrees fit this use case ideally: users can create their own differentiable “vectors”, and I would expect haiku to support these custom structures out of the box. This would let a user get the actual structures back from transformed functions for printing, debugging, and plotting (and for other academic needs). Unfortunately, custom differentiable structures don’t work at the moment.
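To make the pytree idea concrete, here is a minimal sketch in plain JAX, assuming a softplus parameterization for a positive variance. `PositiveScalar` and `neg_log_likelihood` are hypothetical names, not part of haiku or of the linked comment. `jax.grad` returns its gradients in the same custom structure, which is the behavior this issue would like haiku parameters to have.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class PositiveScalar:
    """Hypothetical constrained parameter: stores an unconstrained value,
    exposes a strictly positive one through softplus."""

    def __init__(self, raw):
        self.raw = raw

    @property
    def value(self):
        return jax.nn.softplus(self.raw)

    def tree_flatten(self):
        # The unconstrained value is the only differentiable child.
        return (self.raw,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __repr__(self):
        return f"PositiveScalar(value={self.value})"


def neg_log_likelihood(variance, data):
    # Zero-mean Gaussian negative log-likelihood, up to an additive constant.
    var = variance.value
    return 0.5 * jnp.sum(data**2 / var + jnp.log(var))


data = jnp.array([0.5, -1.0, 2.0])
variance = PositiveScalar(jnp.array(0.0))

# Gradients come back as a PositiveScalar, mirroring the input structure.
grads = jax.grad(neg_log_likelihood)(variance, data)

# One plain gradient-descent step, applied leaf-wise to the pytree.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, variance, grads)
print(updated)
```

Because the optimizer only ever touches `raw`, the variance read through `.value` stays positive after any update.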
Failing example
    import haiku as hk
    import jax
    import jax.numpy as jnp
    from jax.tree_util import register_pytree_node

    class S(hk.Module):
        def __init__(self, x, y):
            super().__init__()
            # These should be trainable parameters:
            self.x = x
            self.y = y

        def __repr__(self):
            return "S(x={}, y={})".format(self.x, self.y)

    def S_flatten(v):
        children = (v.x, v.y)
        aux_data = None
        return (children, aux_data)

    def S_unflatten(aux_data, children):
        return S(*children)

    register_pytree_node(S, S_flatten, S_unflatten)

    def function(s):
        return jnp.sqrt(s.x**2 * s.y**2)

    def loss(x):
        s = S(1.0, 2.0)
        a = hk.get_parameter("free_parameter", shape=[], dtype=jnp.float32, init=jnp.zeros)
        return jnp.sum(function(s) * a * x)

    x = jnp.array([2.0])
    forward = hk.transform(loss)
    key = jax.random.PRNGKey(42)
    params = forward.init(key, x)
    print(params)
    # Reported output: only `free_parameter` is tracked; s.x and s.y are missing.
    # frozendict({
    #   '~': frozendict({'free_parameter': DeviceArray(0., dtype=float32)}),
    # })
Thanks
Support for extracting module info in a creator has landed 😄 Here’s an example colab using it to extract all info to a dict outside the function: https://colab.research.google.com/drive/1tt9ifYFsxvSSXaFAz_Oq59Im8QY4S16o
Using it inside a transformed function is documented here: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.experimental.custom_creator
@tomhennigan, I found out that flax has support for dataclasses, and it covers most of what I needed. I haven’t tried it with haiku, but I believe it should work with haiku out of the box. I would expect JAX to handle a dataclass implicitly, but it looks like it cannot, at least not without flax. Do you have plans to do something similar?
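Registering a dataclass with JAX directly takes only a few lines. The sketch below assumes a hypothetical `register_dataclass_pytree` helper that treats every dataclass field as a pytree child. Without registration, JAX treats the whole dataclass instance as a single leaf, which is why it does not "just work".

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_leaves, tree_map


def register_dataclass_pytree(cls):
    """Hypothetical helper: register a dataclass so each field is a pytree child."""
    names = [f.name for f in dataclasses.fields(cls)]

    def flatten(obj):
        # Children are the field values in declaration order; no static aux data.
        return tuple(getattr(obj, n) for n in names), None

    def unflatten(aux_data, children):
        return cls(*children)

    register_pytree_node(cls, flatten, unflatten)
    return cls


@dataclasses.dataclass
class Plain:  # not registered: JAX sees one opaque leaf
    mean: float
    log_scale: float


@register_dataclass_pytree
@dataclasses.dataclass
class Gaussian:  # registered: JAX sees two leaves
    mean: float
    log_scale: float


print(tree_leaves(Plain(1.0, 2.0)))  # a single leaf: the whole Plain object
g = Gaussian(1.0, 2.0)
print(tree_leaves(g))  # [1.0, 2.0]
doubled = tree_map(lambda v: 2 * v, g)


def objective(p):
    return p.mean**2 + jnp.exp(p.log_scale)


# Gradients are returned as a Gaussian with the same fields.
grads = jax.grad(objective)(Gaussian(1.0, 0.0))
print(grads)
```

Recent JAX releases also ship `jax.tree_util.register_dataclass`, and flax provides `flax.struct.dataclass`; either may be preferable to a hand-rolled helper like this one.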