Stax modules as functions using a parameter tracer

Stax's functional API is in a way very elegant and fits nicely with Jax's overall style, but I think there are a few places that could be improved without deviating from the Jax feel or resorting to OO.

A pet peeve of mine is how tedious it can be to combine different modules in non-conventional ways (e.g., a combination of serial and parallel with parameter reuse). It doesn't seem that you can do much better than implementing both an init and an apply function, even when the custom module doesn't create parameters itself. The result is that you essentially have to write the flow of which module to call when twice, following different and somewhat incompatible function signatures.

For example, when wanting to reuse parameters, you can't simply mirror the apply calls in init since you would want to avoid initializing twice. But now you lose the shape inference and have to hack in a solution, either by writing the shape inference by hand or by wrapping things in a function and using jax.eval_shape (roughly as sketched below). Since modules that don't create parameters themselves, e.g., ones that only serve to connect others together, only need an apply function, it should be possible to have an API that requires just that one function while generating the initialization function automatically.
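For illustration, here is roughly what that workaround looks like with the existing API (a sketch only: TwiceDense is a made-up module that applies one Dense layer twice with shared parameters, and the import path is the current jax.example_libraries.stax, which was jax.experimental.stax when this issue was filed):

import jax
import jax.numpy as jnp
from jax.example_libraries import stax  # jax.experimental.stax at the time of this issue

def TwiceDense(out_dim):
  # Apply one Dense layer twice, sharing its parameters.
  dense_init, dense_apply = stax.Dense(out_dim)

  def init_fun(rng, input_shape):
    # Initialize the shared layer exactly once.
    mid_shape, dense_params = dense_init(rng, input_shape)
    # We can't call dense_init again for the final shape (that would create a
    # second parameter set), so trace dense_apply with abstract inputs instead.
    out = jax.eval_shape(dense_apply, dense_params,
                         jax.ShapeDtypeStruct(mid_shape, jnp.float32))
    return out.shape, (dense_params,)

  def apply_fun(params, inputs, **kwargs):
    dense_params, = params
    return dense_apply(dense_params, dense_apply(dense_params, inputs))

  return init_fun, apply_fun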

I've been toying with the idea of leveraging Jax's tracing to achieve this. One could define a stax module which, when initialize is called, just calls the apply function with some sort of parameter tracer.

from collections import namedtuple


class StaxModule(namedtuple("StaxModule", "init_fun, apply")):

  def __call__(self, params, inputs, **kwargs):
    # let's ignore proper `rng` handling (splitting, etc.) for now
    if isinstance(params, ParameterTracer) and self.init_fun is not None:
      params.init_with(lambda: self.init_fun(kwargs["rng"], inputs))
    return self.apply(params, inputs, **kwargs)

  def initialize(self, rng, inputs):
    # bootstrap tracing and shape inference by calling apply with a
    # tracer standing in for the (not yet created) parameters
    param_tracer = ParameterTracer()
    out_shape = self(param_tracer, inputs, rng=rng)

    # compute the parameters (or fetch them if they were evaluated while tracing)
    params = param_tracer.evaluate()

    return out_shape, params

Now I'll admit that I haven't really worked much with tracers in general, so I'm not all that familiar with the best patterns to use; this API/design is probably suboptimal and could merit some tweaks. Even so, it would let you handle cases like the following effortlessly:

@stax.module  # some helper function to wrap it in a StaxModule
def Foo(module1, module2):

  def apply_fun(params, inputs, **kwargs):
    param1, param2 = params
    out = module1(param1, inputs, **kwargs)
    out = module2(param2, out, **kwargs)
    return module1(param1, out, **kwargs)  # reuse module1's parameters

  return apply_fun

foo = Foo(module1, module2)
out_shape, params = foo.initialize(rng, inputs)

out = foo(params, inputs)

Optional bonuses:

Shape inference

In addition, with that machinery in place, we could simplify the init_fun API so that it only handles creating parameters, leaving the shape inference entirely up to Jax's tracing of the apply_fun given parameters of a certain shape.
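For instance, a leaf layer under that simplified API might look something like this (a hypothetical sketch: this Dense is not the existing stax.Dense, and its init_fun receives example inputs rather than an input shape, to match the initialize signature in the StaxModule sketch above):

import jax.numpy as jnp
from jax import random
from jax.nn.initializers import glorot_normal, normal

def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):

  def init_fun(rng, inputs):
    # only create the parameters; no output shape is computed here
    k1, k2 = random.split(rng)
    W = W_init(k1, (inputs.shape[-1], out_dim))
    b = b_init(k2, (out_dim,))
    return (W, b)

  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return jnp.dot(inputs, W) + b

  return StaxModule(init_fun, apply_fun)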

Handling rng (e.g., splitting, packing/unpacking)

Automatically handling the rng would essentially require the same machinery as the parameters, so it would take very little to set it up at the same time. Where I'm less certain is whether this is something we would even want. I'm leaning towards thinking that it wouldn't be too magical if we unpacked rng like the parameters instead of splitting it explicitly. If you are aware of random.split, I don't think this example would feel odd:

@stax.module
def Bar(module1, module2):

  def apply_fun(params, inputs, rng, **kwargs):
    param1, param2 = params
    rng1, rng2, rng3 = rng
    out = module1(param1, inputs, rng1, **kwargs)
    out = module2(param2, out, rng2, **kwargs)
    return module1(param1, out, rng3, **kwargs)

  return apply_fun

Natural handling of vmap

With all this in place, it would be possible to support the use of vmap on modules without having to worry about its effect on the shapes at initialization (mostly relevant when vectorizing over axes other than 0). We could treat modules like any other function.
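As a rough sketch of what that could look like (hypothetical: foo and params are the module and parameters from the earlier example, and the inputs here are batched along axis 1 instead of axis 0):

import jax
import jax.numpy as jnp

batched_inputs = jnp.ones((784, 32))  # (features, batch): the batch axis is 1, not 0
batched_foo = jax.vmap(foo, in_axes=(None, 1), out_axes=1)  # parameters are shared
out = batched_foo(params, batched_inputs)  # no special shape handling at init time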

Conclusion

I think this kind of machinery would go a long way towards removing some of Stax's rough edges, especially when it comes to supporting unconventional NN architectures. It would make parameter reuse much more user friendly than it currently is, without needing to keep track of which modules were already initialized (e.g., trax's solution). This essentially allows us to keep the modules immutable.

Hopefully I didn't miss some hidden difficulty, but I wouldn't expect something like this to be overly complex. I'm happy to help with its implementation, though some pointers on how best to handle the tracing would be very helpful!

As a final example of how it lightens the implementation, here is what stax.serial and stax.parallel would look like (with the "automatic" handling of rng):

@stax.module  # some helper function to wrap it in a StaxModule
def serial(*layers):

  def apply_fun(params, inputs, rng, **kwargs):
    for layer, param, l_rng in zip(layers, params, rng):
      inputs = layer(param, inputs, l_rng, **kwargs)
    return inputs

  return apply_fun


@stax.module
def parallel(*layers):

  def apply_fun(params, inputs, rng, **kwargs):
    return [
        layer(param, l_in, l_rng, **kwargs)
        for layer, param, l_in, l_rng in zip(layers, params, inputs, rng)
    ]

  return apply_fun
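Usage would then mirror the foo example above (hypothetical, reusing the Dense sketched earlier; how the single rng gets distributed to the per-layer l_rng values is assumed to be handled by the same machinery that handles the parameters):

import jax.numpy as jnp
from jax import random

rng = random.PRNGKey(0)
inputs = jnp.ones((32, 784))  # dummy batch

net = serial(Dense(512), Dense(10))
out_shape, params = net.initialize(rng, inputs)
out = net(params, inputs, rng)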

Edit: thinking back, I realized that it might not be possible to tell how many items to generate from an a, b, c = gen statement. That could be replaced by a more explicit a, b, c = gen.split(3) call, similar to how rngs work.
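To make that a bit more concrete, here is one purely illustrative way the tracer object could behave, written as plain Python rather than actual Jax tracing (only init_with and evaluate come from the StaxModule sketch above; everything else, including split, is made up):

class ParameterTracer:
  # Illustrative recorder that stands in for `params` during initialization.

  def __init__(self):
    self._params = None     # filled in by a leaf module's init_with
    self._children = None   # filled in when a composite module calls split

  def init_with(self, thunk):
    # thunk builds this module's parameters, e.g. lambda: init_fun(rng, inputs)
    if self._params is None:
      self._params = thunk()
    return self._params

  def split(self, n):
    # hand out one child tracer per sub-module, mirroring random.split
    if self._children is None:
      self._children = tuple(ParameterTracer() for _ in range(n))
    return self._children

  def evaluate(self):
    # gather the created parameters into a nested tuple
    if self._children is not None:
      return tuple(child.evaluate() for child in self._children)
    return self._params

Under this variant, the param1, param2 = params unpacking in the examples above would become param1, param2 = params.split(2).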

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 1
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

1 reaction
jekbradbury commented, Nov 11, 2019

This is some very cool work, and similar ideas have been bubbling around inside and outside the team for a few months now (with perhaps the first effort in this direction being @j-towns’s pointy stax).

We definitely agree that stax is limiting, and one of the biggest missing features is precisely this kind of “pointiness” that would let you write complex composite modules once rather than twice when they don’t fit cleanly into stax’s existing “point-free” combinators. Similarly, we’ve also had the instinct that JAX’s tracing machinery should be able to help, since any solution to pointiness essentially involves evaluating the user’s neural network layer definitions in two contexts (while initializing and while applying), which should be a great use case for alternative interpretation.

The specific mechanism proposed here is very compelling (enough that it’s a worthwhile productivity improvement even if it’s implemented in Python outside the JAX machinery, as with jaxnet and a couple libraries that aren’t yet open sourced and work atop JAX similarly to how Sonnet works atop TensorFlow). But unfortunately the way the JAX core works means that it’s impossible to implement it as a tracer: the only things visible to JAX transformations are abstract or concrete values (multidimensional arrays and scalars) and the primitive computations that use them—not Python data structures or attribute access.

Instead, I’ve been working on what I hope is the next best thing: a generic mechanism for tagging intermediate values in JAX computations (#1660) that can be used to build a fairly similar neural net API to the one proposed here. Your Foo example would look something like this:

def Foo(module1, module2):
  def apply_fun(key, inputs, **kwargs):
    key1, key2 = random.split(key)
    out = module1(key1, inputs, **kwargs)
    out = module2(key2, out, **kwargs)
    return module1(key1, out, **kwargs)
  return apply_fun

foo = Foo(module1, module2)
_, params = collect(foo)(rng, inputs)
out = inject(foo, params)(inputs)

My current plan is to work to land the “tagging” infrastructure (#1660) and a minimal “stax2” that uses it, while also enabling authors of existing high-level libraries like jaxnet, trax, and the others to benefit from the infrastructure if they’d like (it should improve the ability to compose those libraries with JAX transformations) without breaking their existing users.

We want JAX to be a great substrate for exploration in neural net API design, especially in a functional style, so we’re always happy to see proposals like this 🙂. And don’t necessarily take my “impossible” as dispositive! The JAX core infrastructure is a land of many wonders and you may well find what you need there to build the library you’re looking for.

0 reactions
gehring commented, Sep 24, 2020

Closing this since I feel it isn't relevant anymore given the many maturing third-party libraries. It is probably best to let other projects develop this side of things and keep jax focused on the low-level functionality.
