Recommended way to associate metadata with parameters
I have a project where I'm using `custom_creator` to compute some auxiliary information for each parameter. I'd like to end up with two pytrees with the same structure, `params` and `aux`, then do something like `jax.tree_util.tree_map(my_function, params, aux)`.
(It's not possible to compute the auxiliary info directly from the `params` pytree, because it needs contextual information that is only available from knowing which modules are being used, etc.)
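For reference, `jax.tree_util.tree_map` accepts several pytrees and calls the function once per leaf position, which is why `aux` has to mirror the structure of `params`. Below is a minimal sketch that assumes a Haiku-style nested dict of parameters. The tree contents and `my_function` (which scales each parameter by its aux value) are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters laid out like a Haiku Linear module's params.
params = {"linear": {"w": jnp.ones((7, 17)), "b": jnp.zeros(17)}}
# Auxiliary info with exactly the same tree structure, one value per parameter.
aux = {"linear": {"w": 2.0, "b": 0.5}}


def my_function(p, a):
    # Hypothetical per-parameter operation: scale the parameter by its aux value.
    return p * a


# tree_map walks both trees in lockstep and pairs up corresponding leaves.
scaled = jax.tree_util.tree_map(my_function, params, aux)
```

The result has the same structure as `params`. If `aux` were missing a key, `tree_map` would raise an error instead of silently mispairing leaves.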
My current solution is to use a custom creator that returns a named tuple containing both the parameter and the extra info, and a custom getter that unwraps it and hands modules only the parameter:
```python
from collections import namedtuple

import haiku as hk
import jax
import jax.numpy as jnp

# Pairs each parameter with its auxiliary information.
ParamAndAux = namedtuple('ParamAndAux', ['param', 'aux'])


def my_creator(next_creator, shape, dtype, init, context):
    # Create the parameter as usual, then attach the auxiliary info to it.
    param = next_creator(shape, dtype, init)
    aux = 12345  # Replace this with actually doing something interesting
    return ParamAndAux(param=param, aux=aux)


def my_getter(next_getter, value, context):
    # Hand modules the bare parameter. This returns without calling
    # next_getter, so any getters that next_getter would have run are skipped.
    if isinstance(value, ParamAndAux):
        return value.param
    return next_getter(value)


def split(params_and_aux):
    # Treat each ParamAndAux as a leaf of the outer tree and each
    # (param, aux) pair as a two-leaf inner tree, then swap the two levels.
    inner_structure = jax.tree_util.tree_structure((0, 0))
    outer_structure = jax.tree_util.tree_structure(
        params_and_aux,
        is_leaf=lambda n: isinstance(n, ParamAndAux),
    )
    return jax.tree_util.tree_transpose(
        outer_structure,
        inner_structure,
        params_and_aux)


def main():
    def f(x):
        with hk.custom_creator(my_creator), hk.custom_getter(my_getter):
            return hk.Linear(17)(x)

    model = hk.without_apply_rng(hk.transform(f))
    params_and_aux = model.init(jax.random.PRNGKey(0), jnp.zeros(7))
    params, aux = split(params_and_aux)  # params and aux end up with the same structure as desired


if __name__ == '__main__':
    main()
```
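One alternative to the `tree_transpose` step is to map over the tree twice with `is_leaf`, so that traversal stops at each `ParamAndAux` node and the lambda can pull out one field. This is offered as a suggestion rather than what the thread settled on. It assumes the same `ParamAndAux` named tuple. The example tree is made up to mimic the parameters of a Haiku `Linear` module.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

ParamAndAux = namedtuple("ParamAndAux", ["param", "aux"])


def is_param_and_aux(node):
    # Stop descending at ParamAndAux; otherwise tree_map would recurse into the tuple.
    return isinstance(node, ParamAndAux)


def split(params_and_aux):
    # Pull one field out of every ParamAndAux, keeping the surrounding structure.
    params = jax.tree_util.tree_map(
        lambda n: n.param, params_and_aux, is_leaf=is_param_and_aux)
    aux = jax.tree_util.tree_map(
        lambda n: n.aux, params_and_aux, is_leaf=is_param_and_aux)
    return params, aux


# Made-up tree shaped like the output of model.init above.
tree = {
    "linear": {
        "w": ParamAndAux(jnp.ones((7, 17)), 12345),
        "b": ParamAndAux(jnp.zeros(17), 12345),
    }
}
params, aux = split(tree)
```

The `tree_transpose` version requires every `ParamAndAux` to flatten to exactly two leaves, so an `aux` of `None` or a dict would raise a leaf-count error. This version extracts each field regardless of what `aux` holds.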
Is there a better way to accomplish this? One issue with this approach is that if any other getters or creators are added after mine, they probably won’t work.
Hi @davisyoshida, I can't think of a better solution, and someone internally is using this exact pattern, so it might be useful to know that there is at least one other person who agrees with us.

One enhancement you might consider would be for your type to implement `__jax_array__`. Then JAX operations would understand how to unpack it, and you might be able to avoid the custom getter.

Re the order of creators/getters, I understand this might be a source of issues (I suspect just for the getter), but in practice I think it is usually quite easy for folks to re-order getters in their program (and it is not typical to have lots of them), so I would hope this would only be a small amount of friction.

Thanks for the info. I didn't know about `__jax_array__` before, and haven't been able to find any documentation about it. Is there somewhere I can read about how JAX makes use of it?
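The `__jax_array__` suggestion could look roughly like the sketch below. It assumes a recent JAX. `__jax_array__` is an experimental, largely undocumented hook, and which `jax.numpy` functions honor it has varied between releases, so the sketch relies only on `jnp.asarray` calling it. The named tuple is swapped for a hypothetical frozen dataclass so that the wrapper is not itself a pytree and stays an ordinary leaf.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ParamAndAux:
    """Hypothetical wrapper pairing a parameter with auxiliary metadata.

    Dataclasses are not pytrees unless registered, so jax.tree_util
    treats each instance as a single leaf.
    """
    param: Any
    aux: Any

    def __jax_array__(self):
        # Experimental hook: JAX may call this when converting the object to an array.
        return self.param


pa = ParamAndAux(param=jnp.ones(3), aux=12345)
as_array = jnp.asarray(pa)  # the wrapper is unpacked via __jax_array__

# Because the wrapper is a leaf, splitting a tree of them needs no is_leaf argument.
tree = {"w": pa, "b": ParamAndAux(param=jnp.zeros(2), aux=678)}
params = jax.tree_util.tree_map(lambda n: n.param, tree)
aux = jax.tree_util.tree_map(lambda n: n.aux, tree)
```

Whether Haiku's own shape checks in `get_parameter` accept such a wrapper without a getter is not shown here, so the custom getter may still be needed in practice.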