Avoid recompilation caused by variably shaped arrays?
I'm working with a function that computes a likelihood given a 1D array. The problem is that this array has variable length, so the `jit`'ed function triggers recompilation whenever the length changes, which happens frequently in my program and causes a significant performance hit.
I know that there are various levels of abstraction, and I see the existence of Unshaped arrays in the docs: "JAX can also trace at higher levels of abstraction, like Unshaped, but that's not currently the default for any transformation".
First, is it even possible with JAX to compile without treating the shape as a compile-time constant? Would using Unshaped arrays solve this? If so, how can I tell JAX to trace at this level of abstraction?
Thank you!
Another great question!
There are two separate issues here worth disentangling:
1. At what level of abstraction JAX traces (and hence specializes) your Python code.
2. What XLA needs in order to compile the staged-out program.
The levels of abstraction in JAX pertain to the former. It's true that JAX can trace-specialize at the Unshaped level, but there's an interesting tradeoff between the amount of abstraction and the restrictions required for Python code to be traceable. For example, we couldn't trace `mean0 = lambda x: x.sum(0) / x.shape[0]` at the Unshaped level, while if we specialize on shape that Python function is traceable. (We could trace `lambda x: np.sin(np.cos(x))` at Unshaped just fine, though.) There's a relationship to reverse-mode autodiff too: we can transpose `lambda x: x.sum(0)` when specialized on shape, but not unshaped, because we need to know what size to broadcast to. (We could transpose `lambda x: x * 5.` though.) Just from the point of view of this tradeoff, it's been a nice sweet spot to specialize on shapes.
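For concreteness, here's a minimal sketch of how shape specialization shows up at trace time and in transposition (the exact jaxpr printout varies by JAX version):

```python
import jax
import jax.numpy as jnp

mean0 = lambda x: x.sum(0) / x.shape[0]

# Tracing specializes on shape: the trace-time abstract value knows x has
# shape (3, 2), so the Python expression x.shape[0] is just the integer 3.
print(jax.make_jaxpr(mean0)(jnp.ones((3, 2))))
# Prints something like:
# { lambda ; a:f32[3,2]. let b:f32[2] = reduce_sum[axes=(0,)] a
#                            c:f32[2] = div b 3.0 in (c,) }

# Transposition needs shapes too: the gradient of a sum broadcasts the output
# cotangent back to the input's shape, which must be known.
print(jax.grad(lambda x: x.sum())(jnp.arange(4.0)))  # [1. 1. 1. 1.]
```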
But you asked about compilation in particular. Let's set issue 1 aside (say we have a Python function that we can specialize and transform at the Unshaped level, like `lambda x: np.sin(np.cos(x))`). Once we have a staged-out program (i.e. a jaxpr) that is only specialized at the Unshaped level, could we compile one XLA program for it, and thus not have to recompile for different shapes?

The answer is no: XLA HLO programs include concrete shapes in their type system. That is, just for issue 2 and unrelated to issue 1, we'd need to specialize the staged-out Unshaped program on argument shapes to be able to generate an XLA HLO program for it, and we'd have to re-invoke the compiler for each new shape specialization. XLA has its own reasons for specializing on shape: for instance, shape-specialized programs can be statically memory-allocated and layout-optimized, and decisions like whether to fuse certain operations or rematerialize intermediates depend entirely on the shapes and sizes involved. That's what gives XLA its incredible optimization power, especially on an accelerator like a TPU.
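As a quick illustration of that per-shape recompilation (a sketch, with placeholder names), a Python-level print inside a jitted function runs only when JAX retraces, which is also when XLA recompiles:

```python
import jax
import jax.numpy as jnp

@jax.jit
def log_likelihood(x):
    # This Python print runs only during tracing, i.e. once per new input
    # shape/dtype, so it marks exactly when a fresh XLA compilation happens.
    print(f"tracing for shape {x.shape}")
    return jnp.sum(jax.scipy.stats.norm.logpdf(x))

log_likelihood(jnp.ones(3))  # traces and compiles for shape (3,)
log_likelihood(jnp.ones(3))  # cache hit: no print, no recompilation
log_likelihood(jnp.ones(4))  # new shape -> traces and compiles again
```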
You could say you're willing to give up some of those shape-specialized optimization opportunities if it meant you didn't have to recompile as often. That makes sense! (In fact, on CPU and GPU, GEMM kernels typically aren't compile-time specialized on shapes, though a kernel may be selected at runtime based on argument shapes.) People are exploring array-oriented compilers at those different points in the design space, and it's exciting stuff. I expect JAX will take advantage of these kinds of capabilities as they emerge, in XLA and/or related technologies in the MLIR world. But for now the best compiler (XLA) requires shape-specialized programs. (You could actually lower an Unshaped-specialized jaxpr to a TensorFlow program pretty effectively, though I'm not sure whether the TF graph executor (a.k.a. interpreter) has the performance characteristics you're looking for. Does it?)
But wait! We have a secret weapon here! JAX, transform!
One of the most exciting new JAX transformations is `mask`. It's a prototype that needs a few more weeks of work, and of course it's totally undocumented as per our custom, but it can help us because `mask` can let us automatically implement shape-polymorphic semantics in shape-monomorphic programs. Here's what I mean: `mask` does the padding-and-masking bookkeeping for you (more on the by-hand version below), so a single shape-monomorphic compiled program can serve many logical lengths.

I'm not sure I'd recommend you use `mask` quite yet, but avoiding recompilation is one example of what it can do. (There's more! While we can do `jit`+`mask` to avoid recompilations, we can do `vmap`+`mask` to do ragged batching, e.g. for RNNs. We can also use the same machinery to do import-time shape checking of your code.) There are some examples in the tests.

We realized at one point that Unshaped (which doesn't specialize on any shape information, or even rank) was way too big a leap away from Shaped to be useful. Instead, `mask` specializes on partial shape information, in particular on parametrically polymorphic dimensions (which can be concrete too: do as much specialization as you like).
The `mask` transform doesn't do anything you couldn't do by hand. That is, you can pad and mask things by hand so as to implement shape-polymorphic semantics in a shape-monomorphic way; a rough sketch of that manual approach is below. But it's a big pain which scales badly with the complexity of the programs you want to handle. It's just like autodiff: sure, you can write derivative code by hand, but it's much better to have an automatic, composable transformation do it for you.
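For instance, here's a minimal sketch of that by-hand approach, assuming a normal log-likelihood and a fixed bucket size (the names `PAD` and `padded_log_likelihood` are just placeholders, not a JAX API):

```python
import jax
import jax.numpy as jnp
import numpy as np

PAD = 128  # a fixed bucket size; every input is padded to this length

@jax.jit
def padded_log_likelihood(x_padded, n):
    # Mask out the padding so it contributes nothing to the result.
    valid = jnp.arange(PAD) < n
    logp = jax.scipy.stats.norm.logpdf(x_padded)
    return jnp.sum(jnp.where(valid, logp, 0.0))

def log_likelihood(x):
    # Pad to the bucket size; one compilation serves every length <= PAD.
    x_padded = np.zeros(PAD, dtype=np.float32)
    x_padded[: len(x)] = x
    return padded_log_likelihood(x_padded, len(x))

print(log_likelihood(np.array([0.1, -0.3, 2.0])))
print(log_likelihood(np.random.randn(17)))  # same compiled program, no retrace
```

The tradeoff is wasted compute on the padding and having to pick bucket sizes up front, and this kind of bookkeeping is exactly what `mask` is meant to automate.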
Anyway, here's the short version of answers to your questions:

- No: the XLA HLO programs we compile must be shape-specialized, even though JAX can stage some programs out of Python with less specialization.
- No: while Unshaped would avoid trace-specializing on shapes, we have no way to compile a single program with no shape specialization. (Looking forward, it's probably better to specialize on some shape information, even if just parametric polymorphism.)
- There's no user API for that, though if you want to fiddle with it for fun you can change `make_jaxpr` not to use `xla.abstractify` (which is effectively just `lambda x: raise_to_shaped(core.get_aval(x))`) and instead abstract to the Unshaped level (maybe like `lambda x: Unshaped(core.get_aval(x).dtype)`).
Hope that answers your question, though unfortunately I'm answering it in the negative. You could play with `mask`, but I suspect it might be tricky until we post at least minimal documentation, and even then we don't know whether it'll work well for your use case (are you on CPU or GPU?).

I'm going to close this issue on the hunch that this was enough of an answer, but please reopen (or open new ones) as needed. Questions are always welcome!
Any updates on the `mask` transformation?