RFC: Elegy/Treex Ecosystem Next Versions
Here are some ideas for the Treeo, Treex, and Elegy libraries which hopefully add some quality-of-life improvements so they can stand the test of time a bit better.
Immutability
Treeo/Treex have adopted a mutable/stateful design in favor of simplicity. While careful propagation of the mutated state inside jitted functions guarantees overall immutable behaviour thanks to pytree cloning, there are some downsides:
- Asymmetry between traced (jitted, vmapped, etc.) and non-traced functions: stateful operations could mutate the original object in non-traced functions, while this wouldn’t happen in traced functions.
- There are no hints for the user that state needs to be propagated.
Proposal
Add an Immutable mixin to Treeo and have Treex use it for its base Treex class. This work already started in cgarciae/treeo#13, and the mixin will do the following:
- Enforce immutability via __setattr__ by raising a RuntimeError when a field is updated.
- Expose a replace(**kwargs) -> Tree method that lets you replace the values of the desired fields but returns a new object.
- Expose a mutable(method="__call__")(*args, **kwargs) -> (output, Tree) method that lets you call another method containing mutable operations in an immutable fashion.
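A minimal, dependency-free sketch of how such a mixin could behave is shown below. It is not the code from cgarciae/treeo#13: the class name Immutable matches the proposal, but the private _unlocked flag, the use of copy.copy, and the example MyTree with its n field are illustrative assumptions. Copies are shallow, so nested trees would stay locked inside mutable(); a real implementation would have to handle that.

```python
import copy


class Immutable:
    """Hypothetical sketch of the proposed mixin, not treeo's actual code."""

    # Class-level default: instances are locked unless explicitly unlocked.
    _unlocked = False

    def __setattr__(self, name, value):
        # Reject field updates unless we are inside replace() or mutable().
        if not self._unlocked:
            raise RuntimeError(
                f"Cannot set '{name}' on immutable {type(self).__name__}; use replace()"
            )
        object.__setattr__(self, name, value)

    def replace(self, **kwargs):
        # Shallow-copy the object and set the new field values on the copy only.
        new = copy.copy(self)
        object.__setattr__(new, "_unlocked", True)
        for name, value in kwargs.items():
            setattr(new, name, value)
        object.__setattr__(new, "_unlocked", False)
        return new

    def mutable(self, method="__call__"):
        # Return a function that runs `method` on an unlocked copy and returns
        # (output, updated_copy); the original object is never modified.
        def wrapper(*args, **kwargs):
            new = copy.copy(self)
            object.__setattr__(new, "_unlocked", True)
            try:
                output = getattr(new, method)(*args, **kwargs)
            finally:
                object.__setattr__(new, "_unlocked", False)
            return output, new

        return wrapper


class MyTree(Immutable):
    def __init__(self, n):
        # Bypass the lock during construction.
        object.__setattr__(self, "n", n)

    def acc_sum(self, x):
        self.n += x
        return self.n


tree = MyTree(1)
output, updated = tree.mutable(method="acc_sum")(2)  # output == 3, updated.n == 3, tree.n == 1
bigger = tree.replace(n=10)  # bigger.n == 10, tree.n is still 1
```

Any direct assignment such as tree.n = 5 raises a RuntimeError, which is exactly the hint to the user that the update has to be captured in a new object.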
Creating an immutable Tree via the Immutable mixin would look like this:
import treeo as to
class MyTree(to.Tree, to.Immutable):
    ...
Additionally, Treeo could also expose an ImmutableTree class, so users who are not comfortable with mixins could do it like this:
class MyTree(to.ImmutableTree):
    ...
Examples
Field updates
Mutably you would update a field like this:
tree.n = 10
Whereas in the immutable version you use replace and get a new tree:
tree = tree.replace(n=10)
Stateful Methods
Now if your Tree class had some stateful method such as:
def acc_sum(self, x):
    self.n += x
    return self.n
Mutably you could simply use it like this:
output = tree.acc_sum(x)
Now if your tree is immutable you would use mutable, which lets you run this method but captures the updates in a new instance that is returned along with the output of the method:
output, tree = tree.mutable(method="acc_sum")(x)
Alternatively, you could also use it as a function transformation via treeo.mutable like this:
output, tree = treeo.mutable(tree.acc_sum)(tree, x)
Random State
Treex’s Modules currently treat random state simply as internal state. Because it is hidden, it is actually a bit more difficult to reason about and can cause a variety of issues such as:
- Changing state when you don’t want it to do so
- Freezing state by accident if you forget to propagate updates
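To make the "freezing" failure mode concrete, here is a small dependency-free simulation that also shows the traced/non-traced asymmetry described earlier. It assumes a hypothetical NoisyModule that keeps an integer seed as hidden state, uses Python's random.Random in place of a JAX PRNGKey, and uses copy.deepcopy in a hypothetical traced_call helper as a stand-in for the pytree cloning that happens inside jit.

```python
import copy
import random


class NoisyModule:
    """Hypothetical module that keeps its RNG seed as hidden internal state."""

    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        # Draw noise from the hidden seed, then advance it (a stateful update).
        noise = random.Random(self.seed).random()
        self.seed += 1
        return x + noise


def traced_call(module, x):
    # Stand-in for a jitted call: the function body works on a rebuilt copy
    # of the pytree, so its mutations never reach the caller's object.
    clone = copy.deepcopy(module)
    return clone(x)


eager = NoisyModule(seed=0)
a, b = eager(0.0), eager(0.0)  # seed advances in place, so a != b

frozen = NoisyModule(seed=0)
c, d = traced_call(frozen, 0.0), traced_call(frozen, 0.0)  # seed never advances, so c == d
```

Because the updated seed is never returned from the "traced" call, every call silently reuses the same randomness unless the user remembers to propagate the state by hand.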
Proposal
Remove the Rng kind and create an apply method, similar to (but simpler than) Flax’s apply, with the following signature:
def apply(
    self,
    key: Optional[PRNGKey],
    *args,
    method="__call__",
    mutable: bool = True,
    **kwargs
) -> (Output, Treex): ...
As you can see, this method accepts an optional key as its first argument and then just the *args and **kwargs for the function. Regular usage would change from:
y = model(x)
to
y, model = model.apply(key, x)
However, if the module is stateless and doesn’t require RNG state you can still call the module directly.
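A minimal sketch of the proposed apply follows. It assumes a hypothetical Module base class that exposes the key to the called method through a private _key attribute only for the duration of the call (the proposal does not say how methods receive the key), and it uses an int seed with Python's random.Random as a stand-in for a JAX PRNGKey. The mutable flag from the signature is left out. The hypothetical Dropout module shows both kinds of state: it needs a key, and it also updates an ordinary calls counter that ends up in the returned module.

```python
import copy
import random
from typing import Any, Optional, Tuple


class Module:
    """Hypothetical base class sketching the proposed apply(), not Treex's code."""

    _key: Optional[int] = None  # visible to methods only while apply() runs

    def apply(self, key: Optional[int], *args, method: str = "__call__", **kwargs) -> Tuple[Any, "Module"]:
        # Run `method` on a copy with the key attached, so state updates made
        # by the method land in the returned module instead of in `self`.
        new = copy.copy(self)
        new._key = key
        output = getattr(new, method)(*args, **kwargs)
        new._key = None  # the key is not kept around as hidden state
        return output, new


class Dropout(Module):
    def __init__(self, rate):
        self.rate = rate
        self.calls = 0  # ordinary (non-RNG) state updated by __call__

    def __call__(self, xs):
        if self._key is None:
            raise ValueError("Dropout needs a key; call it through apply()")
        rng = random.Random(self._key)
        self.calls += 1
        # Zero each element with probability `rate`, rescale the survivors.
        return [0.0 if rng.random() < self.rate else x / (1 - self.rate) for x in xs]


model = Dropout(rate=0.5)
y, model = model.apply(42, [1.0, 2.0, 3.0, 4.0])  # the same key always gives the same mask
```

Since the key is an explicit argument, reusing or splitting it is a visible decision at the call site rather than a side effect hidden inside the module.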
Losses and Metrics
Current Losses and Metrics in Treex (which actually come from Elegy) are great! Since losses and metrics are mostly just Pytrees with simple state, it would be nice to extract them into their own library and, with some minor refactoring, build a framework-independent losses and metrics library that could be used by anyone in the JAX ecosystem. We could eventually create a library called jax_tools (or something) that contains utilities such as Loss and Metric interfaces plus implementations of common losses and metrics, and maybe other utilities.
As for the Metric API, I was recently looking at clu from the Flax team and found some nice ideas that could make the implementation of distributed code simpler.
Proposal
Make Metric immutable and update its API to:
import typing as tp
from abc import ABC, abstractmethod
import jax
import jax.numpy as jnp
M = tp.TypeVar("M", bound="Metric")
class Metric(ABC):
    @abstractmethod
    def update(self: M, **kwargs) -> M:
        ...
    @abstractmethod
    def reset(self: M) -> M:
        ...
    @abstractmethod
    def compute(self) -> tp.Any:
        ...
    @abstractmethod
    def aggregate(self: M) -> M:
        # could even default to:
        # jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), self)
        ...
    def merge(self: M, other: M) -> M:
        # stack both metrics along a new leading axis, then reduce it
        # (requires Metric subclasses to be registered as pytrees)
        stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), self, other)
        return stacked.aggregate()
    def batch_updates(self: M, **kwargs) -> M:
        return self.reset().update(**kwargs)
This is very similar to the Keras API, with the exception of the aggregate method, which is incredibly useful when syncing devices in a distributed setup.
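To illustrate the proposed API, here is a hypothetical Mean metric written against it. To stay dependency-free it uses a frozen dataclass with Python floats instead of a registered JAX pytree, and a two-element list stands in for jnp.stack along a leading axis; aggregate then sums over that axis, mirroring the default suggested in the code comment above.

```python
import dataclasses
from typing import List, Union


@dataclasses.dataclass(frozen=True)
class Mean:
    """Hypothetical immutable Mean metric following the proposed API."""

    # Running sum and count. After stacking (see merge) each field holds a
    # list with one entry per metric instead of a single float.
    total: Union[float, List[float]] = 0.0
    count: Union[float, List[float]] = 0.0

    def update(self, *, values: List[float]) -> "Mean":
        # Returns a new metric; frozen=True makes in-place updates impossible.
        return dataclasses.replace(
            self, total=self.total + sum(values), count=self.count + len(values)
        )

    def reset(self) -> "Mean":
        return Mean()

    def compute(self) -> float:
        return self.total / self.count

    def aggregate(self) -> "Mean":
        # Reduce the leading "stacked" axis by summation, the pure-Python
        # counterpart of jnp.sum(x, axis=0).
        return Mean(total=sum(self.total), count=sum(self.count))

    def merge(self, other: "Mean") -> "Mean":
        # Stack the two metrics field by field, then aggregate.
        stacked = Mean(total=[self.total, other.total], count=[self.count, other.count])
        return stacked.aggregate()

    def batch_updates(self, *, values: List[float]) -> "Mean":
        return self.reset().update(values=values)
```

For example, merging Mean().update(values=[1.0, 2.0, 3.0]) with Mean().update(values=[5.0]) gives total 11.0 and count 4.0, so compute() returns 2.75. In a multi-device setup the same pattern would apply to per-device metrics gathered along a leading axis: aggregate reduces them to a single metric before compute is called.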
Elegy Model
Nothing concrete for the moment, but I’m thinking of a PyTorch Lightning-like architecture which would have the following properties:
- The creation of an ElegyModule class (analogous to the LightningModule) that would centralize all the JAX-related parts of the training process. More specifically, it would be a Pytree and would expose a framework-agnostic API, which means Treeo’s Kind system would no longer be used. Model would then be a regular non-pytree Python object containing a state: ElegyModule field that it maintains and updates in place.
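Since this part is still open, the following is only a toy illustration of the split. Apart from ElegyModule, Model, train_step, and the state field, every name and signature here is hypothetical: an immutable ElegyModule (a frozen dataclass standing in for a JAX pytree) whose train_step is a pure function returning logs and a new module, and a plain Model object whose fit loop swaps self.state in place. The one-parameter linear regression only gives train_step something to do.

```python
import dataclasses
from typing import Dict, List, Tuple


@dataclasses.dataclass(frozen=True)
class ElegyModule:
    """Hypothetical immutable training state (a stand-in for a JAX pytree)."""

    w: float = 0.0
    lr: float = 0.1

    def train_step(self, x: float, y: float) -> Tuple[Dict[str, float], "ElegyModule"]:
        # Pure function: one gradient step on the loss (w * x - y) ** 2,
        # returning logs plus a new module instead of mutating self.
        err = self.w * x - y
        grad = 2 * err * x
        return {"loss": err**2}, dataclasses.replace(self, w=self.w - self.lr * grad)


class Model:
    """Regular (non-pytree) object that owns the state and updates it in place."""

    def __init__(self, state: ElegyModule):
        self.state = state

    def fit(self, data: List[Tuple[float, float]], epochs: int = 1) -> List[float]:
        losses = []
        for _ in range(epochs):
            for x, y in data:
                # The only mutation happens here, outside any JAX transformation.
                logs, self.state = self.state.train_step(x, y)
                losses.append(logs["loss"])
        return losses
```

With this split, everything that would be jitted lives in ElegyModule, and the user-facing Model keeps the convenient stateful feel without exposing Kinds or other framework-specific metadata.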
Top GitHub Comments
Ah yes, I think it’s fine to re-export jax_metrics since the functions live in submodules, i.e. we have jm.losses.Crossentropy() instead of jm.Crossentropy(). (I would say jm.losses.CrossEntropy() is a better name though; otherwise the naming convention isn’t really consistent.)
Hey @nalzok, thanks for taking the time to write this; opinions of any kind are welcome! This comment will also serve as an update on how the implementation evolved:
Given the proposal also had an apply method, ultimately it was simpler to have a mutable: bool argument in apply, which defaults to True, so the previous example looks identical with apply.
I too like the name Trainer; however, I am hesitant to make the change since it would break code that only uses the high-level API. Maybe we could rename it to Trainer and keep Model as an alias for backward compatibility.
The thing is that Treeo Kinds are additional metadata added to the pytree leaves in order to create more powerful filters, mirroring Flax’s collections. While they simplified parts of the implementation a lot, users have to learn this additional framework. The solution is to use regular pytrees and have the user override a couple of additional methods (this can be automated for specific frameworks).
This is currently being implemented in poets-ai/elegy#232; here is an update on the resulting APIs:
- train_step, test_step, pred_step, init_step
- managed_train_step, managed_test_step, managed_pred_step, managed_init_step
- init, apply