
Design suggestion: allow objects to be closed under JAX's transformation rules

See original GitHub issue

One of the problems with making random number generators into objects (https://github.com/google/jax/issues/2294) is that objects are not closed under JAX’s transformation rules like batching. I’m creating this issue as a discussion point.
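For a minimal illustration of the problem (Counter is a hypothetical example class, not anything from JAX):

import jax.numpy as jnp
from jax import vmap

class Counter:
    def __init__(self, val):
        self.val = val

    def bump(self, x):
        # Returns a new object, not an array.
        return Counter(self.val + x)

c = Counter(1.0)
vmap(lambda x: c.bump(x).val)(jnp.arange(3.))  # fine: the output is an array
# vmap(c.bump)(jnp.arange(3.))  # fails: a Counter is not a valid JAX type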

It seems like what’s necessary is an abstract base class for objects that need to support transformation rules. I don’t know enough to propose something too concrete, but maybe:

from abc import ABC, abstractmethod


class JaxArrayLike(ABC):

    @abstractmethod
    def as_jax_array(self):
        """
        Returns T where T is either a JAX array or a sequence of T.
        """
        raise NotImplementedError

Currently, a simple neural network object is a mess of static members:

from jax import grad, jit, random, vmap
from jax.nn import relu
from jax.scipy.special import logsumexp
import jax.numpy as jnp

step_size = 0.01  # learning rate used by _update


class NeuralNetwork:

    def __init__(self, sizes, key):
        keys = random.split(key, len(sizes) - 1)  # one key per layer
        self.sizes = sizes
        self.weights = [self.random_layer_params(m, n, k)
                        for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

    # Initialization ----------------------------------------------------------
    @staticmethod
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return (scale * random.normal(w_key, (n, m)),
                scale * random.normal(b_key, (n,)))

    # Learning ----------------------------------------------------------------
    @staticmethod
    @jit
    def _update(weights, images, targets):
        grads = grad(NeuralNetwork.loss)(weights, images, targets)
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(weights, grads)]

    @staticmethod
    def loss(weights, images, targets):
        preds = batched_predict(weights, images)
        return -jnp.sum(preds * targets)

    def update(self, images, targets):
        self.weights = self._update(self.weights, images, targets)

    def accuracy(self, images, targets):
        target_class = jnp.argmax(targets, axis=1)
        predicted_class = jnp.argmax(
            batched_predict(self.weights, images), axis=1)
        return jnp.mean(predicted_class == target_class)


# Inference -------------------------------------------------------------------
def predict(weights, image):
    # per-example predictions
    activations = image
    for w, b in weights[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)

    final_w, final_b = weights[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)


batched_predict = vmap(predict, in_axes=(None, 0))

But by deriving from the abstract class, hopefully most of those static methods could be replaced with regular methods, and as_jax_array could return self.weights. Even better would be to add object-oriented structure to the weights (make them, say, a list or graph of objects), which works as long as as_jax_array is also called on the components of its return value. A sketch of what that might look like follows.
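For concreteness, a minimal sketch of the refactored class, assuming the hypothetical JaxArrayLike base class above and transformation machinery that calls as_jax_array (recursively) to reach the underlying arrays:

class NeuralNetwork(JaxArrayLike):

    def __init__(self, sizes, key):
        keys = random.split(key, len(sizes) - 1)
        self.sizes = sizes
        self.weights = [self.random_layer_params(m, n, k)
                        for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

    def as_jax_array(self):
        # Only the weights are transformable state; `sizes` stays static.
        return self.weights

    def random_layer_params(self, m, n, key, scale=1e-2):
        # A regular method now; no @staticmethod dance required.
        w_key, b_key = random.split(key)
        return (scale * random.normal(w_key, (n, m)),
                scale * random.normal(b_key, (n,)))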

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 17 (16 by maintainers)

Top GitHub Comments

1 reaction
jekbradbury commented, Mar 11, 2020

For example, in a model whose inference is done by a call to scan, is there any harm in having the parameters in the carry? Does the compilation of the scan recognize that these parameters don’t change, or should I move the parameters to a different object that is passed into the scan function?

It can’t hurt to make sure constant parameters like this are part of the scan’s non-carried arguments (or are closed over by its body function) rather than being part of its carry. We don’t do any optimization in JAX to detect constant-over-time values in carry position, so we’d be relying on XLA to perform loop-invariant code motion (and XLA’s optimizations on loops are generally less mature than for straight-line code).
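A minimal sketch of the two options, with made-up parameter names (w, b) and a toy recurrence:

import jax.numpy as jnp
from jax import lax, random

key = random.PRNGKey(0)
w = random.normal(key, (4, 4))  # constant over the loop
b = jnp.ones(4)
xs = jnp.zeros((10, 4))         # per-step inputs

# Preferred: close over the constant parameters in the body function,
# so they never appear in the carry.
def step(h, x):
    h = jnp.tanh(jnp.dot(h, w) + b + x)
    return h, h

final, outputs = lax.scan(step, jnp.zeros(4), xs)

# Avoid: threading the constants through the carry; JAX does not detect
# that (w, b) never change, so XLA must rediscover the invariance.
def step_carrying_params(carry, x):
    h, (w, b) = carry
    h = jnp.tanh(jnp.dot(h, w) + b + x)
    return (h, (w, b)), h

(final2, _), outputs2 = lax.scan(
    step_carrying_params, (jnp.zeros(4), (w, b)), xs)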

1 reaction
mattjj commented, Mar 10, 2020

While we like to work with users on code in general, what I think we’re talking about in this issue is changes/additions to JAX’s function transformation API to make it OOP-y. That’s just not on our roadmap. (You can always build your own OOPy interface on top of JAX however you like!)

I wish I had the bandwidth to dig into 500 lines of code right now, but unfortunately I don’t. In general, the time it takes to get a response to an issue is inversely proportional to the length of the issue description (perhaps raised to some greater-than-1 power). If you can distill specific, concrete challenges you have, please open issues for those.

Well, what if you want vmap to also vectorize some (but not all) of the members like self.val? This is what I’m running into.

That’s a pretty concrete question!

You can play a lot of Python games to automatically get at the real functions underneath your objects. Here’s one example:

import jax.numpy as np
from jax import vmap

class A:
  def __init__(self, y, z):
    self.y = y  # let's map over this one
    self.z = z  # but not this one

  def foo(self, x):
    return np.sin(x) + np.cos(self.y) + self.z

def my_vmap(method):
  # Pull the instance off the bound method so its attributes can be
  # passed to vmap as explicit arguments.
  self = method.__self__
  cls = self.__class__
  def function(x, y, z):
    return cls(y, z).foo(x)
  # Map over x and self.y along axis 0; broadcast self.z unchanged.
  return lambda x: vmap(function, (0, 0, None))(x, self.y, self.z)

out = my_vmap(A(np.arange(3.), 4).foo)(np.arange(3.))
print(out)  # [5.        5.381773  4.4931507]

Of course, you can abstract that to be more general. And if you don’t want to mess with Python’s runtime object representation, you can just define your own methods/abstractions to do the same thing.
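For instance, one way to abstract it (a sketch, not anything JAX provides; the names vmap_method and mapped_attrs are made up): rebuild the object from its attribute dict inside the mapped function, and let the caller choose which attributes get a mapped axis.

def vmap_method(method, mapped_attrs):
  self = method.__self__
  cls = self.__class__
  names = sorted(vars(self))
  # Axis 0 for the input and each mapped attribute; None broadcasts the rest.
  in_axes = (0,) + tuple(0 if n in mapped_attrs else None for n in names)

  def function(x, *attr_values):
    obj = cls.__new__(cls)  # rebuild the object without calling __init__
    for name, value in zip(names, attr_values):
      setattr(obj, name, value)
    return getattr(obj, method.__name__)(x)

  attrs = tuple(getattr(self, n) for n in names)
  return lambda x: vmap(function, in_axes)(x, *attrs)

out = vmap_method(A(np.arange(3.), 4).foo, mapped_attrs={'y'})(np.arange(3.))
# Same result as my_vmap above.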

Let’s close this issue, since its current framing (as indicated by the issue title) is pretty big and not something we’re going to pursue, but please follow up with small concrete real-use-case issues 😃
