Recommended way to associate metadata with parameters

I have a project where I’m using custom_creator to compute some auxiliary information for each parameter. I’d like to end up with two pytrees with the same structure, params and aux, and then do something like jax.tree_map(my_function, params, aux). (It’s not possible to compute the auxiliary info directly from the params pytree, as it needs contextual information that’s only available from knowing which modules are being used, etc.)

My current solution is to use a custom creator which returns a named tuple having both the params and the extra info, and a custom getter which ignores that info:

from collections import namedtuple    
import haiku as hk    
import jax    
import jax.numpy as jnp    
    
ParamAndAux = namedtuple('ParamAndAux', ['param', 'aux'])    
    
def my_creator(next_creator, shape, dtype, init, context):    
    param = next_creator(shape, dtype, init)    
    aux = 12345 # Replace this with actually doing something interesting    
    return ParamAndAux(param=param, aux=aux)    
    
def my_getter(next_getter, value, context):    
    if isinstance(value, ParamAndAux):    
        return value.param    
    return next_getter(value)    
    
def split(params_and_aux):    
    inner_structure = jax.tree_util.tree_structure((0, 0))    
    outer_structure = jax.tree_util.tree_structure(    
        params_and_aux,    
        is_leaf=lambda n: isinstance(n, ParamAndAux)
    )    
    return jax.tree_util.tree_transpose(    
        outer_structure,    
        inner_structure,    
        params_and_aux)    
    
    
def main():    
    def f(x):    
        with hk.custom_creator(my_creator), hk.custom_getter(my_getter):    
            return hk.Linear(17)(x)    
    
    model = hk.without_apply_rng(hk.transform(f))    
    params_and_aux = model.init(jax.random.PRNGKey(0), jnp.zeros(7))    
    
    params, aux = split(params_and_aux)   # params and aux end up with the same structure as desired

if __name__ == '__main__':    
    main()

Is there a better way to accomplish this? One issue with this approach is that if any other getters or creators are added after mine, they probably won’t work.
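
As a minimal sketch of the last step described above, the split and the subsequent tree_map can be exercised without Haiku. The nested dict below is a hand-built stand-in for what model.init returns in the issue; the key names, shapes, and aux values are assumptions chosen for illustration, and the lambda stands in for my_function.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

ParamAndAux = namedtuple('ParamAndAux', ['param', 'aux'])

# Hand-built stand-in for the output of model.init (names are illustrative).
params_and_aux = {
    'linear': {
        'w': ParamAndAux(param=jnp.ones((7, 17)), aux=2.0),
        'b': ParamAndAux(param=jnp.zeros(17), aux=3.0),
    }
}

def split(tree):
    # Treat each ParamAndAux as a leaf of the outer tree, then transpose so
    # the (param, aux) pair becomes the outermost level.
    outer = jax.tree_util.tree_structure(
        tree, is_leaf=lambda n: isinstance(n, ParamAndAux))
    inner = jax.tree_util.tree_structure((0, 0))
    return jax.tree_util.tree_transpose(outer, inner, tree)

params, aux = split(params_and_aux)

# params and aux now share one structure, so they can be mapped together.
scaled = jax.tree_util.tree_map(lambda p, a: p * a, params, aux)
```

Here the aux values are plain Python floats, so each one counts as a single leaf; an aux value that is itself a container would change the leaf count and break the transpose.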

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5

Top GitHub Comments

1 reaction
davisyoshida commented, Apr 20, 2022

Thanks for the info. I didn’t know about __jax_array__ before, and haven’t been able to find any documentation about it. Is there somewhere I can read about how JAX makes use of it?

1 reaction
tomhennigan commented, Apr 14, 2022

Hi @davisyoshida, I can’t think of a better solution, and someone internally is using this exact pattern so it might be useful to know that there is at least one other person who agrees with us.

One enhancement you might consider would be for your type to implement __jax_array__. Then JAX operations would understand how to unpack it, and you might be able to avoid the custom getter:

import chex
import jax.numpy as jnp

@chex.dataclass
class Box:
  value: jnp.ndarray

  def __jax_array__(self):
    return self.value

a = Box(value=jnp.ones([]))
x = jnp.ones([])
a + x  # works
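
The same idea can be sketched without chex, using a plain class that also carries metadata. This is only an illustration under assumptions: the Box class and its aux field below are hypothetical, and since __jax_array__ does not appear to be prominently documented, its behavior may differ between JAX versions.

```python
import jax.numpy as jnp

class Box:
    """Hypothetical wrapper holding an array plus some metadata."""

    def __init__(self, value, aux):
        self.value = value
        self.aux = aux

    def __jax_array__(self):
        # JAX calls this when it needs an array in place of the wrapper.
        return self.value

box = Box(jnp.arange(3.0), aux="some metadata")

as_array = jnp.asarray(box)   # explicit conversion goes through __jax_array__
total = jnp.ones(3) + box     # the array operator unwraps the Box operand
```

The metadata stays on the wrapper (box.aux) and is simply dropped whenever JAX unwraps it, which is the behavior the custom getter in the issue provides by hand.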

Re the order of creators/getters, I understand this might be a source of issues (I suspect just for the getter) but in practice I think it is usually quite easy for folks to re-order getters in their program (and it is not typical to have lots of them) so I would hope this would only be a small amount of friction.
