Easier syntax for `eqx.tree_at()`
Hey Patrick,
I’ve been developing dLux with @benjaminpope, whom you recently met for coffee in London! I believe you discussed a gentler syntax for parameter updates, and a more general update function for natively interfacing equinox with probabilistic programming languages to run MCMC algorithms. I think I have found some nice solutions that might be worth integrating into equinox.
The general goal has been to pass in a list/tuple of string/index-based paths that specify where in an arbitrary pytree structure one would like to update parameters (or apply functions to them!), along with a list/tuple of new values/functions to apply to those parameters. Here are the functions and a minimal working example demonstrating their usage:
Update functions:
```python
import equinox as eqx


def get_leaf(pytree, path):
    """
    Recurses down `path` through `pytree` and returns the node at its end.
    """
    key = path[0]
    # Modules are indexed by attribute name; dicts and lists by key/index
    pytree = getattr(pytree, key) if isinstance(pytree, eqx.Module) else pytree[key]
    # Return the node if at the end of the path, else recurse
    return pytree if len(path) == 1 else get_leaf(pytree, path[1:])


def update_pytree(pytree, paths, values):
    """
    Updates the `pytree` leaves specified by `paths` with `values`.
    Callable values are applied to the current leaf; anything else replaces it.
    """
    # Returns a tuple of the leaves specified by paths
    get_leaves = lambda tree: tuple(get_leaf(tree, path) for path in paths)

    # Applies the update if it is a function, else uses it as the new value
    update_leaf = lambda leaf, update: update(leaf) if callable(update) else update

    # Computes the new value of each leaf specified by paths
    new_values = tuple(update_leaf(get_leaf(pytree, path), value)
                       for path, value in zip(paths, values))
    return eqx.tree_at(get_leaves, pytree, new_values)
```
Define two classes to interface with:

```python
class Foo(eqx.Module):
    param1: float

    def __init__(self, p1):
        self.param1 = p1


class Bar(eqx.Module):
    foo: object
    some_dict: dict
    some_list: list
    param1: float

    def __init__(self, foo, some_dict, some_list, param1):
        self.foo = foo
        self.some_dict = some_dict
        self.some_list = some_list
        self.param1 = param1


some_dict = {'entry1': 0., 'entry2': 1., 'foo': Foo(2)}
some_list = [3, 4, Foo(5)]
pytree = Bar(Foo(6), some_dict, some_list, 7)
print(pytree)
```

```
Bar(
  foo=Foo(param1=6),
  some_dict={'entry1': 0.0, 'entry2': 1.0, 'foo': Foo(param1=2)},
  some_list=[3, 4, Foo(param1=5)],
  param1=7
)
```
Define paths to parameters, as well as new values/functions to apply:

```python
paths = [['foo', 'param1'],
         ['some_dict', 'entry1'],
         ['some_dict', 'foo', 'param1'],
         ['some_list', 0],
         ['some_list', 2, 'param1']]
values = [lambda x: -x, 10, 11, lambda x: x**2, 14]

new_pytree = update_pytree(pytree, paths, values)
print(new_pytree)
```

```
Bar(
  foo=Foo(param1=-6),
  some_dict={'entry1': 10, 'entry2': 1.0, 'foo': Foo(param1=11)},
  some_list=[9, 4, Foo(param1=14)],
  param1=7
)
```
As you can see from this example, it can update, or apply functions to, each leaf specified by a path, through arbitrary nesting. I think this syntax will already help non-developers and less experienced users interface with our software. I am already using it in a base class that all others inherit from, so it would make sense to me to move this into equinox. Since the functionality is really about pytrees in general, having it live in `tree.py` could also make sense.
Anyway, I can create a PR if you’re interested; I would love to be able to contribute to this awesome software! Cheers
P.S. I do have some further functionality I use to get a nice interface with probabilistic programming languages like Numpyro, which could be useful for others using equinox to build physical modelling software. They are essentially just lightweight wrappers, but I’m happy to share examples if you’re interested.
Top GitHub Comments
Hey, so I’ve been thinking some more about this, and given the wider use cases that I hadn’t considered, I think the current `tree_at` syntax is likely optimal. I do think there could be value in adding a reduced-functionality `update_pytree` or `update_params` function to equinox. For our use case this would be valuable both for end users and for making interfacing with packages like Numpyro a bit easier. This should only require adding a `__getitem__` method to `eqx.Module`, which simply indexes the module's `__dict__`.

You mentioned in your original reply that you wanted equinox modules to be general pytrees; I think that allowing them to be indexable would actually help serve this goal. With this addition, if we ignore applying functions to leaves, a simple parameter update function would look like this:
Using the same example as above, it would be called like this:
This could also be extended to apply functions to leaves by similarly passing in a list/tuple of booleans that select, for each entry of the `values` list, whether it replaces the leaf or is applied to it.
Let me know what you think!
Yeah, no problem, that does make sense. Thanks for the help, though!