Support for custom pytrees


Hello, Haiku team! Thanks a lot for making such an awesome library.

I’m interested in sequential probabilistic models. Parameters of probabilistic models are normally constrained; a simple example is a variance, which can only be positive. I gave an example and an explanation of constrained parameters in https://github.com/deepmind/dm-haiku/issues/16#issuecomment-602087358. Pytrees fit this use case ideally: users can create their own differentiable “vectors”, and I would expect Haiku to support these custom structures out of the box. That would let a user get actual structures back from transformed functions for printing, debugging, and plotting (and for other academic needs). Unfortunately, custom differentiable structures don’t work at the moment.
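The following is a minimal sketch in plain JAX (no Haiku) of the kind of structure described above. It is not code from the issue. The class `Normal`, its fields, and the softplus constraint are assumptions for illustration. The pytree leaves are unconstrained values, and the positive variance is derived from them on demand. Gradients and parameter updates come back as the same custom structure.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Normal:
  """Hypothetical Gaussian whose variance is constrained to be positive."""

  def __init__(self, mean, raw_variance):
    self.mean = mean
    self.raw_variance = raw_variance  # unconstrained: any real number

  @property
  def variance(self):
    # softplus maps the real line onto (0, inf), so this is always positive.
    return jax.nn.softplus(self.raw_variance)

  def tree_flatten(self):
    children = (self.mean, self.raw_variance)  # the differentiable leaves
    aux_data = None                            # no static metadata
    return children, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    return cls(*children)

  def __repr__(self):
    return f"Normal(mean={self.mean}, raw_variance={self.raw_variance})"


def neg_log_likelihood(dist, data):
  var = dist.variance
  return 0.5 * jnp.sum(jnp.log(2 * jnp.pi * var) + (data - dist.mean) ** 2 / var)


dist = Normal(jnp.array(0.0), jnp.array(0.0))
data = jnp.array([1.0, 2.0, 3.0])

# The gradient has the same structure as `dist`: a Normal instance.
grads = jax.grad(neg_log_likelihood)(dist, data)

# One gradient-descent step; tree_map also returns a Normal.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, dist, grads)
print(updated, updated.variance)
```

Because the optimizer only ever touches `raw_variance`, the variance cannot leave the positive half-line.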

Failing example

```python
import jax
import jax.numpy as jnp
import haiku as hk
from jax.tree_util import register_pytree_node


class S(hk.Module):
  def __init__(self, x, y):
    super().__init__()
    # These are parameters:
    self.x = x
    self.y = y

  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)


def S_flatten(v):
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)


def S_unflatten(aux_data, children):
  return S(*children)


register_pytree_node(S, S_flatten, S_unflatten)


def function(s):
  return jnp.sqrt(s.x**2 * s.y**2)


def loss(x):
  s = S(1.0, 2.0)
  a = hk.get_parameter("free_parameter", shape=[], dtype=jnp.float32, init=jnp.zeros)
  # Sum the array directly; jnp.sum does not accept a Python list.
  return jnp.sum(function(s) * a * x)


x = jnp.array([2.0])
forward = hk.transform(loss)
key = jax.random.PRNGKey(42)
params = forward.init(key, x)
print(params)
```

Only `free_parameter` ends up in `params`. The `x` and `y` fields of `S` are not tracked:

```
frozendict({
  '~': frozendict({'free_parameter': DeviceArray(0., dtype=float32)}),
})
```

Thanks

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 10

Top GitHub Comments

4 reactions
tomhennigan commented, Apr 26, 2020

Support for extracting module information in a custom creator has landed 😄. Here’s an example Colab that uses it to extract all of this information into a dict outside the function: https://colab.research.google.com/drive/1tt9ifYFsxvSSXaFAz_Oq59Im8QY4S16o

Using it inside a transformed function is documented here: https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.experimental.custom_creator

1 reaction
awav commented, May 4, 2020

@tomhennigan, I found out that Flax supports dataclasses, which covers most of what I need. I haven’t tried it with Haiku, but I believe it should work with Haiku out of the box. JAX should handle dataclasses implicitly, but it looks like it can’t, at least not without Flax. Do you have plans to do something similar?
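JAX does not treat a dataclass as a pytree node unless it is registered. The sketch below is a minimal hand-rolled decorator that registers one using only the standard library and `jax.tree_util.register_pytree_node`. It treats every field as a child, with no static fields. The decorator `pytree_dataclass` and the class `Params` are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


def pytree_dataclass(cls):
  """Hypothetical helper: make `cls` a frozen dataclass registered as a pytree."""
  cls = dataclasses.dataclass(frozen=True)(cls)
  names = [f.name for f in dataclasses.fields(cls)]

  def flatten(obj):
    # Every field is a child (leaf or subtree); there is no static aux data.
    return tuple(getattr(obj, n) for n in names), None

  def unflatten(aux_data, children):
    return cls(*children)

  register_pytree_node(cls, flatten, unflatten)
  return cls


@pytree_dataclass
class Params:
  loc: jnp.ndarray
  log_scale: jnp.ndarray


p = Params(loc=jnp.array(1.0), log_scale=jnp.array(0.0))
g = jax.grad(lambda q: (q.loc * jnp.exp(q.log_scale)) ** 2)(p)
print(g)  # the gradient is a Params instance with the dataclass repr
```

Flax’s `flax.struct.dataclass` does the same job and also lets you mark non-differentiable fields with `flax.struct.field(pytree_node=False)`. The hand-rolled version above has no such option.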


