
Easier syntax for `eqx.tree_at()`


Hey Patrick,

I’ve been developing dLux with @benjaminpope, who you recently met up with for a coffee in London! I believe you discussed a gentler syntax for parameter updates, and a more general update function that would give equinox a native interface with probabilistic programming languages for running MCMC algorithms. I think I have found some nice solutions that might be worth integrating into equinox.

The general goal has been to pass in a list/tuple of string/index-based paths that specify where in an arbitrary pytree structure one would like to update parameters (or apply functions to them!), along with a list/tuple of new values/functions to apply to those parameters. Here are the functions, and a minimal working example demonstrating their usage:

Update functions:


import equinox as eqx

def get_leaf(pytree, path):
    """
    Recurses down `path` (field names, dict keys or list indices) through the pytree
    """
    key = path[0]
    # Modules are indexed by field name; dicts and lists are indexed directly
    pytree = pytree.__dict__[key] if isinstance(pytree, eqx.Module) else \
             pytree[key]
    
    # Return param if at the end of path, else recurse
    return pytree if len(path) == 1 else get_leaf(pytree, path[1:])

def update_pytree(pytree, paths, values):
    """
    Updates the `pytree` leaves specified by `paths` with `values`
    """
    # Returns a tuple of leaves specified by paths
    get_leaves = lambda pytree : tuple([get_leaf(pytree, paths[i]) \
                                        for i in range(len(paths))])
    
    # Applies leaf_update to the leaf if it is a function, else uses it as the new value
    update_leaf = lambda leaf, leaf_update: leaf_update(leaf) \
            if callable(leaf_update) else leaf_update
    
    # Computes the new value for each leaf specified by paths
    update_values = tuple([update_leaf(get_leaf(pytree, paths[i]), \
                            values[i]) for i in range(len(paths))])
    
    return eqx.tree_at(get_leaves, pytree, update_values)

Define two classes to interface with, and build a nested pytree from them:

class Foo(eqx.Module):
    param1 : float
    
    def __init__(self, p1):
        self.param1 = p1
        
class Bar(eqx.Module):
    foo : object
    some_dict : dict
    some_list : list
    param1 : float
    
    def __init__(self, foo, some_dict, some_list, param1):
        self.foo = foo
        self.some_dict = some_dict
        self.some_list = some_list
        self.param1 = param1

some_dict = {'entry1': 0., 'entry2': 1., 'foo': Foo(2)}
some_list = [3, 4, Foo(5)]
pytree = Bar(Foo(6), some_dict, some_list, 7)

print(pytree)
# Bar(
#   foo=Foo(param1=6),
#   some_dict={'entry1': 0.0, 'entry2': 1.0, 'foo': Foo(param1=2)},
#   some_list=[3, 4, Foo(param1=5)],
#   param1=7
# )

Define the paths to the parameters, along with the new values (or functions) to apply to them:

paths = [['foo', 'param1'], 
         ['some_dict', 'entry1'],
         ['some_dict', 'foo', 'param1'],
         ['some_list', 0],
         ['some_list', 2,'param1']]
values = [lambda x: -x, 10, 11, lambda x: x**2, 14]

new_pytree = update_pytree(pytree, paths, values)

print(new_pytree)
# Bar(
#   foo=Foo(param1=-6),
#   some_dict={'entry1': 10, 'entry2': 1.0, 'foo': Foo(param1=11)},
#   some_list=[9, 4, Foo(param1=14)],
#   param1=7
# )

As you can see from this example, it can update, or apply functions to, each leaf specified by a path, through arbitrary levels of nesting. I think this syntax alone will help non-developers and less experienced users interface with our software. I am already using it in a Base class that all others inherit from, so it would make sense to me to move this into equinox. The functionality is really about pytrees in general, so having it live in tree.py could also make sense.

Anyway, I can create a PR if you’re interested; I would love to be able to contribute to this awesome software! Cheers

P.S. I do have some further functionality that I use to get a nice interface with probabilistic programming languages like Numpyro, which could be useful for others using equinox to build more physical modelling software. They are essentially just lightweight wrappers, but I’m happy to share examples if you’re interested.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

1 reaction
LouisDesdoigts commented, Sep 12, 2022

Hey, so I’ve been thinking some more about this, and given the wider use cases that I hadn’t considered, I think the current tree_at syntax is likely optimal. I still think there could be value in adding a reduced-functionality update_pytree or update_params function to equinox. For our use case this would be valuable for end users, and it would also make interfacing with packages like Numpyro a bit easier.

Such a function should only require adding a __getitem__ method to eqx.Module that looks the requested field up in the instance’s __dict__:

  def __getitem__(self, index):
      return self.__dict__[index]

You mentioned in your original reply that you wanted equinox modules to be general pytrees; I think allowing them to be indexable would actually help serve this goal. With this addition, and ignoring the application of functions to leaves, a simple parameter update function would look like this:

import equinox as eqx

def get_leaf(pytree, path):
    """
    Recurses down the path of the pytree
    """
    # Index one level down, then return the param if at the end of path, else recurse
    pytree = pytree[path[0]]
    return pytree if len(path) == 1 else get_leaf(pytree, path[1:])

def update_pytree(pytree, paths, values):
    """
    Updates the `pytree` leaves specified by `paths` with `values`
    """
    # Returns a tuple of leaves specified by paths
    get_leaves = lambda pytree : tuple([get_leaf(pytree, paths[i]) \
                                        for i in range(len(paths))])
    
    return eqx.tree_at(get_leaves, pytree, values)

With the same example as above, usage would look like this:

class Base(eqx.Module):
    def __getitem__(self, index):
        return self.__dict__[index]
    
class Foo(Base):
    param1 : float
    
    def __init__(self, p1):
        self.param1 = p1
        
class Bar(Base):
    foo : object
    some_dict : dict
    some_list : list
    param1 : float
    
    def __init__(self, foo, some_dict, some_list, param1):
        self.foo = foo
        self.some_dict = some_dict
        self.some_list = some_list
        self.param1 = param1

some_dict = {'entry1': 0., 'entry2': 1., 'foo': Foo(2)}
some_list = [3, 4, Foo(5)]
pytree = Bar(Foo(6), some_dict, some_list, 7)

print(pytree)
# Bar(
#   foo=Foo(param1=6),
#   some_dict={'entry1': 0.0, 'entry2': 1.0, 'foo': Foo(param1=2)},
#   some_list=[3, 4, Foo(param1=5)],
#   param1=7
# )

paths = [['foo', 'param1'], 
         ['some_dict', 'entry1'],
         ['some_dict', 'foo', 'param1'],
         ['some_list', 0],
         ['some_list', 2,'param1']]
values = [9, 10, 11, 12, 14]

new_pytree = update_pytree(pytree, paths, values)
print(new_pytree)
# Bar(
#   foo=Foo(param1=9),
#   some_dict={'entry1': 10, 'entry2': 1.0, 'foo': Foo(param1=11)},
#   some_list=[12, 4, Foo(param1=14)],
#   param1=7
# )

This could also be robustly extended to apply functions to leaves by passing in a matching list/tuple of booleans that specifies, for each entry in the ‘values’ list, whether it replaces the leaf or is applied to it as a function.

Let me know what you think!

0 reactions
LouisDesdoigts commented, Sep 15, 2022

Yeah no problem, that does make sense. Thanks for the help though!
