Stax modules as functions using a parameter tracer
Stax's functional API is in a way very elegant and fits nicely with JAX's overall style, but I think there are a few places that can be improved without deviating from the JAX feel or resorting to OO.
A pet peeve of mine is how tedious it can be when you want to combine different modules in non-conventional ways (e.g., a combination of `serial` and `parallel` with parameter reuse). There doesn't seem to be a way to do much better than implementing both an `init` and an `apply` function, even when the custom module doesn't create parameters itself. The result is that you essentially have to write the flow of which module to call when twice, following two different and somewhat incompatible function signatures.
For example, when you want to reuse parameters, you can't just make the `apply` calls inside `init`, since you would want to avoid initializing twice. But then you lose the shape inference and have to hack in a solution, either by writing the shape inference by hand or by wrapping things in a function and using `jax.eval_shape`. Since modules without parameters only need an `apply` function to be defined (e.g., ones that only serve to connect others together), it should be possible to have an API that only requires that one function while generating the initialization function itself.
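To make that workaround concrete, here is a minimal sketch (with a made-up `apply_fun` and parameter shapes) of using `jax.eval_shape` to recover the output shape without running the real computation:

```python
import jax
import jax.numpy as jnp

# Hypothetical already-initialized apply function whose output shape we
# want without spending any FLOPs.
def apply_fun(params, inputs):
    w, b = params
    return jnp.dot(inputs, w) + b

w = jnp.zeros((3, 5))
b = jnp.zeros((5,))
inputs = jnp.zeros((2, 3))

# eval_shape traces apply_fun with abstract values only, so no real
# computation happens; we get back a ShapeDtypeStruct.
out = jax.eval_shape(apply_fun, (w, b), inputs)
print(out.shape)  # (2, 5)
```

This works, but it requires you to already have (or fake) concrete parameters, which is exactly the boilerplate the proposal below tries to eliminate.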
I've been toying with the idea of leveraging JAX's tracing to achieve this. One could define a Stax module which, when `initialize` is called, just calls the `apply` function with some sort of parameter tracer.
```python
class StaxModule(namedtuple("StaxModule", ["init_fun", "apply"])):
    def __call__(self, params, inputs, **kwargs):
        # let's just ignore handling `rng` for now
        if isinstance(params, ParameterTracer) and self.init_fun is not None:
            params.init_with(lambda: self.init_fun(kwargs["rng"], inputs))
        return self.apply(params, inputs, **kwargs)

    def initialize(self, rng, inputs):
        # code to bootstrap tracing and shape inference
        param_tracer = ParameterTracer()
        out_shape = self.__call__(param_tracer, inputs, rng=rng)
        # compute the parameters (or fetch them if they were evaluated while tracing)
        params = param_tracer.evaluate()
        return out_shape, params
```
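`ParameterTracer` is left undefined above; here is one hypothetical, plain-Python sketch of what it could look like (not a real JAX tracer): each node either records an initializer thunk or fans out into child tracers, and `evaluate` runs the recorded thunks to produce the concrete parameter tree.

```python
class ParameterTracer:
    """Hypothetical stand-in for real parameters during initialize()."""

    def __init__(self):
        self._thunk = None       # initializer for this node, if any
        self._children = []      # sub-tracers handed to sub-modules

    def init_with(self, thunk):
        # remember how to build this module's parameters; only the first
        # initializer is kept, so reused modules are initialized once
        if self._thunk is None:
            self._thunk = thunk

    def split(self, n):
        # explicit split into child tracers, mirroring how rng keys work
        self._children = [ParameterTracer() for _ in range(n)]
        return self._children

    def evaluate(self):
        # materialize this node's parameters, or recurse into children
        if self._thunk is not None:
            return self._thunk()
        return [child.evaluate() for child in self._children]
```

The explicit `split(n)` sidesteps the problem (noted in the edit at the end) that plain tuple unpacking can't tell the tracer how many children to create.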
Now, I'll admit that I haven't really worked much with tracers, so I'm not all that familiar with the best patterns to use; this API/design is probably suboptimal and could merit some tweaks. Even so, it would let you handle cases like the following effortlessly:
```python
@stax.module  # some helper function to wrap it in a StaxModule
def Foo(module1, module2):
    def apply_fun(params, inputs, **kwargs):
        param1, param2 = params
        out = module1(param1, inputs, **kwargs)
        out = module2(param2, out, **kwargs)
        # reuse module1's parameters a second time
        return module1(param1, out, **kwargs)
    return apply_fun

foo = Foo(module1, module2)
out_shape, params = foo.initialize(rng, inputs)
out = foo(params, inputs)
```
Optional bonuses

Shape inference

In addition, with that machinery in place, we could simplify the `init_fun` API to have it just handle creating parameters, leaving the shape inference entirely up to JAX's tracing of the `apply_fun` given parameters of a certain shape.
Handling `rng` (e.g., splitting, packing/unpacking)

Automatically handling the `rng` would essentially require the same machinery as the parameters, so it would take very little to set it up at the same time. Where I'm less certain is whether this is something we would even want. I'm leaning towards thinking it wouldn't be too magical if we unpacked `rng` like the parameters instead of splitting. If you are aware of `random.split`, I don't think this example would feel odd.
```python
@stax.module
def Bar(module1, module2):
    def apply_fun(params, inputs, rng, **kwargs):
        param1, param2 = params
        rng1, rng2, rng3 = rng
        out = module1(param1, inputs, rng1, **kwargs)
        out = module2(param2, out, rng2, **kwargs)
        # reuse module1's parameters with a fresh rng
        return module1(param1, out, rng3, **kwargs)
    return apply_fun
```
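For reference, the `rng1, rng2, rng3 = rng` unpacking above corresponds to what you would write by hand today with `jax.random.split`:

```python
from jax import random

# What the automatic unpacking would amount to with today's explicit
# API: deriving three independent keys from one.
rng = random.PRNGKey(0)
rng1, rng2, rng3 = random.split(rng, 3)
```

The unpacked version just moves the `split` bookkeeping into the module machinery.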
Natural handling of `vmap`

With all this in place, it would be possible to support the use of `vmap` on modules without having to worry about its effect on the shapes at initialization (mostly relevant when vectorizing over axes other than 0). We could treat modules like any other function.
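As a sketch of what that buys us, here is a plain apply function (hypothetical, written without the module machinery) batched over a non-leading axis with `jax.vmap`, with no shape bookkeeping at initialization:

```python
import jax
import jax.numpy as jnp

def apply_fun(params, inputs):      # inputs: (features,)
    w, b = params
    return jnp.dot(inputs, w) + b

params = (jnp.ones((3, 5)), jnp.zeros((5,)))

# Vectorize over axis 1 of the inputs; params are shared (None), and the
# batch dimension of the output is also placed at axis 1.
batched = jax.vmap(apply_fun, in_axes=(None, 1), out_axes=1)

x = jnp.ones((3, 8))                # batch of 8 along axis 1
print(batched(params, x).shape)     # (5, 8)
```

Since initialization would itself just trace `apply_fun`, the same treatment would extend to whole modules.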
Conclusion

I think this kind of machinery would go a long way towards removing some of Stax's rough edges, especially when it comes to supporting unconventional NN architectures. It would make parameter reuse much more user-friendly than it currently is, without needing to keep track of which modules were already initialized (e.g., trax's solution). This essentially allows us to keep the modules immutable.
Hopefully I haven't missed some hidden difficulty, but I wouldn't expect something like this to be overly complex. I'm happy to help with its implementation, though some pointers on how best to handle the tracing would be very helpful!
As a final example of how it lightens the implementation, here is how `stax.serial` and `stax.parallel` would look (with the "automatic" handling of `rng`):
```python
@stax.module  # some helper function to wrap it in a StaxModule
def serial(*layers):
    def apply_fun(params, inputs, rng, **kwargs):
        for layer, param, l_rng in zip(layers, params, rng):
            inputs = layer(param, inputs, l_rng)
        return inputs
    return apply_fun

@stax.module
def parallel(*layers):
    def apply_fun(params, inputs, rng, **kwargs):
        return [
            layer(param, l_in, l_rng)
            for layer, param, l_in, l_rng in zip(layers, params, inputs, rng)
        ]
    return apply_fun
```
Edit: thinking back, I realized that it might not be possible to tell how many items to generate from an `a, b, c = gen` statement. That could be replaced by a more explicit `a, b, c = gen.split(3)` call, similar to how rngs work.
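A hypothetical `KeyGen` helper (name and API invented here for illustration) shows how such an explicit `split` could work for the rng side, mirroring `jax.random.split`:

```python
from jax import random

class KeyGen:
    """Hypothetical lazy rng carrier: subkeys are only made when asked."""

    def __init__(self, key):
        self._key = key

    def split(self, n):
        # derive n child generators, keeping one fresh key for ourselves
        subkeys = random.split(self._key, n + 1)
        self._key = subkeys[0]
        return [KeyGen(k) for k in subkeys[1:]]

    @property
    def key(self):
        return self._key

gen = KeyGen(random.PRNGKey(0))
g1, g2, g3 = gen.split(3)  # explicit count, so unpacking is well-defined
```

Because `split` takes an explicit count, the unpacking ambiguity above never arises, and the parameter tracer could expose the same interface.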
Issue Analytics
- Created 4 years ago
- Reactions: 1
- Comments: 7 (4 by maintainers)
This is some very cool work, and similar ideas have been bubbling around inside and outside the team for a few months now (with perhaps the first effort in this direction being @j-towns’s pointy stax).
We definitely agree that stax is limiting, and one of the biggest missing features is precisely this kind of “pointiness” that would let you write complex composite modules once rather than twice when they don’t fit cleanly into stax’s existing “point-free” combinators. Similarly, we’ve also had the instinct that JAX’s tracing machinery should be able to help, since any solution to pointiness essentially involves evaluating the user’s neural network layer definitions in two contexts (while initializing and while applying), which should be a great use case for alternative interpretation.
The specific mechanism proposed here is very compelling (enough that it’s a worthwhile productivity improvement even if it’s implemented in Python outside the JAX machinery, as with jaxnet and a couple libraries that aren’t yet open sourced and work atop JAX similarly to how Sonnet works atop TensorFlow). But unfortunately the way the JAX core works means that it’s impossible to implement it as a tracer: the only things visible to JAX transformations are abstract or concrete values (multidimensional arrays and scalars) and the primitive computations that use them—not Python data structures or attribute access.
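That limitation is easy to observe directly: `jax.make_jaxpr` records only the primitive operations on abstract array values, with Python containers flattened away.

```python
import jax
import jax.numpy as jnp

def f(params, x):
    # `params` is a Python dict, but the tracer never sees the dict
    # itself -- only the flattened array leaves and the primitives
    # (dot_general, add) applied to them.
    return jnp.dot(x, params["w"]) + params["b"]

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
x = jnp.ones((4, 3))
print(jax.make_jaxpr(f)(params, x))
```

The printed jaxpr contains `dot_general` and `add` on abstract arrays; the dict structure, string keys, and attribute accesses do not appear, which is why a "parameter tracer" can't be built as a JAX tracer.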
Instead, I've been working on what I hope is the next best thing: a generic mechanism for tagging intermediate values in JAX computations (#1660) that can be used to build a fairly similar neural net API to the one proposed here, in which your `Foo` example would translate fairly directly. My current plan is to work to land the "tagging" infrastructure (#1660) and a minimal "stax2" that uses it, while also enabling authors of existing high-level libraries like jaxnet, trax, and the others to benefit from the infrastructure if they'd like (it should improve the ability to compose those libraries with JAX transformations) without breaking their existing users.
We want JAX to be a great substrate for exploration in neural net API design, especially in a functional style, so we’re always happy to see proposals like this 🙂. And don’t necessarily take my “impossible” as dispositive! The JAX core infrastructure is a land of many wonders and you may well find what you need there to build the library you’re looking for.
Closing this since I feel like it isn't relevant anymore given the many maturing third-party libraries. It is probably best to let other projects develop this side of things and keep JAX focused on the low-level functionality.