Design suggestion: allow objects to be closed under JAX's transformation rules
One of the problems with making random number generators into objects (https://github.com/google/jax/issues/2294) is that objects are not closed under JAX's transformation rules, such as batching. I'm creating this issue as a discussion point.
It seems like what’s necessary is an abstract base class for objects that need to support transformation rules. I don’t know enough to propose something too concrete, but maybe:
```python
from abc import ABC, abstractmethod

class JaxArrayLike(ABC):
    @abstractmethod
    def as_jax_array(self):
        """
        Returns T where T is either a JAX array or a sequence of T.
        """
        raise NotImplementedError
```
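For concreteness, here is a sketch of a subclass implementing that interface (the `Pair` class is illustrative, not part of the proposal):

```python
from abc import ABC, abstractmethod
import jax.numpy as jnp

class JaxArrayLike(ABC):
    @abstractmethod
    def as_jax_array(self):
        """Return a JAX array or a sequence of them."""
        raise NotImplementedError

class Pair(JaxArrayLike):
    """A weights-and-bias pair that exposes its arrays to transformations."""
    def __init__(self, w, b):
        self.w, self.b = w, b

    def as_jax_array(self):
        return [self.w, self.b]

p = Pair(jnp.ones((2, 3)), jnp.zeros(3))
print([a.shape for a in p.as_jax_array()])  # [(2, 3), (3,)]
```

JAX's present-day mechanism for teaching transformations about custom containers is `jax.tree_util.register_pytree_node`, which plays a role similar to the `as_jax_array` hook proposed here.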
Currently, a simple neural network object is a mess of static members:
```python
import jax.numpy as jnp
from jax import grad, jit, random, vmap
from jax.nn import relu
from jax.scipy.special import logsumexp

step_size = 1e-2  # learning rate (value illustrative; assumed module-level)


class NeuralNetwork:
    def __init__(self, sizes, key):
        keys = random.split(key, len(sizes))
        self.sizes = sizes
        self.weights = [self.random_layer_params(m, n, k)
                        for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

    # Initialization ----------------------------------------------------------
    @staticmethod
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return (scale * random.normal(w_key, (n, m)),
                scale * random.normal(b_key, (n,)))

    # Learning ----------------------------------------------------------------
    @staticmethod
    @jit
    def _update(weights, images, targets):
        grads = grad(NeuralNetwork.loss)(weights, images, targets)
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(weights, grads)]

    @staticmethod
    def loss(weights, images, targets):
        preds = batched_predict(weights, images)
        return -jnp.sum(preds * targets)

    def update(self, images, targets):
        self.weights = self._update(self.weights, images, targets)

    def accuracy(self, images, targets):
        target_class = jnp.argmax(targets, axis=1)
        predicted_class = jnp.argmax(
            batched_predict(self.weights, images), axis=1)
        return jnp.mean(predicted_class == target_class)


# Inference -------------------------------------------------------------------
def predict(weights, image):
    # per-example predictions
    activations = image
    for w, b in weights[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = relu(outputs)
    final_w, final_b = weights[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)


batched_predict = vmap(predict, in_axes=(None, 0))
```
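The `in_axes=(None, 0)` argument tells `vmap` to share the weights across the batch while mapping over the leading axis of the inputs. A minimal self-contained illustration of that semantics:

```python
import jax.numpy as jnp
from jax import vmap

def affine(params, x):
    w, b = params
    return jnp.dot(w, x) + b

params = (jnp.eye(2), jnp.zeros(2))
batch = jnp.arange(6.0).reshape(3, 2)  # three examples of size 2

# None: params are broadcast unchanged; 0: map over x's leading (batch) axis
batched = vmap(affine, in_axes=(None, 0))
print(batched(params, batch).shape)  # (3, 2)
```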
But by deriving from the abstract class, hopefully most of those static methods could be replaced with regular methods, and `as_jax_array` could return `self.weights`. Even better would be to add object-oriented structure to the weights (make them, say, a list or graph of objects), which works as long as `as_jax_array` is also called recursively on the components of its return value.
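A sketch of that nested, object-oriented structure (the class names and the recursive `leaves` flattener are hypothetical, standing in for whatever machinery would apply `as_jax_array` recursively):

```python
import jax.numpy as jnp

class Layer:
    def __init__(self, w, b):
        self.w, self.b = w, b
    def as_jax_array(self):
        return (self.w, self.b)

class Network:
    def __init__(self, layers):
        self.layers = layers
    def as_jax_array(self):
        return self.layers  # components are objects, flattened recursively

def leaves(x):
    """Hypothetical flattener: apply as_jax_array until only arrays remain."""
    if hasattr(x, "as_jax_array"):
        return leaves(x.as_jax_array())
    if isinstance(x, (list, tuple)):
        return [leaf for item in x for leaf in leaves(item)]
    return [x]

net = Network([Layer(jnp.ones((3, 2)), jnp.zeros(3)),
               Layer(jnp.ones((1, 3)), jnp.zeros(1))])
print(len(leaves(net)))  # 4
```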
Issue Analytics
- Created: 4 years ago
- Comments: 17 (16 by maintainers)
It can’t hurt to make sure constant parameters like this are part of the scan’s non-carried arguments (or are closed over by its body function) rather than being part of its carry. We don’t do any optimization in JAX to detect constant-over-time values in carry position, so we’d be relying on XLA to perform loop-invariant code motion (and XLA’s optimizations on loops are generally less mature than for straight-line code).
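A sketch of that advice with `jax.lax.scan` (the function here is illustrative): close over the constant in the body function instead of threading it through the carry.

```python
import jax
import jax.numpy as jnp

def cumsum_scaled(xs, scale):
    # Good: `scale` is closed over by the body, so it stays a constant in
    # the scanned computation rather than riding along in the carry, where
    # we'd be relying on XLA's loop-invariant code motion to hoist it out.
    def body(carry, x):
        new = carry + scale * x
        return new, carry  # (next carry, per-step output)
    _, ys = jax.lax.scan(body, 0.0, xs)
    return ys

print(cumsum_scaled(jnp.arange(4.0), 2.0))  # [0. 0. 2. 6.]
```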
While we like to work with users on code in general, what I think we’re talking about in this issue is changes/additions to JAX’s function transformations API to be OOP-y. That’s just not on our roadmap. (You can always build your own OOPy interface on top of JAX however you like!)
I wish I had the bandwidth to dig into 500 lines of code right now, but unfortunately I don’t. In general, the time it takes to get a response to an issue is inversely proportional to the length of the issue description (perhaps raised to some greater-than-1 power). If you can distill specific, concrete challenges you have, please open issues for those.
That’s a pretty concrete question!
You can play a lot of Python games to automatically get at the real functions underneath your objects. Here’s one example:
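One such game, sketched here (illustrative, not necessarily the example originally posted): use `functools.partial` to reach the plain function before it becomes a bound method, and mark `self` static so `jax.jit` traces it as a constant.

```python
from functools import partial
import jax
import jax.numpy as jnp

class Affine:
    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    # `partial` applies jit to the real function underneath the method;
    # static_argnums=0 tells jit to treat `self` as a hashable constant.
    @partial(jax.jit, static_argnums=0)
    def __call__(self, x):
        return self.scale * x + self.shift

f = Affine(2.0, 1.0)
print(f(jnp.arange(3.0)))  # [1. 3. 5.]
```

Because `self` is static, each distinct instance triggers its own trace, and mutating an instance's attributes after the first call won't be reflected in the already-compiled function.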
Of course, you can abstract that to be more general. And if you don’t want to mess with Python’s runtime object representation, you can just define your own methods/abstractions to do the same thing.
Let’s close this issue, since its current framing (as indicated by the issue title) is pretty big and not something we’re going to pursue, but please follow up with small concrete real-use-case issues 😃