Avoid recompilation caused by variably shaped arrays?

I’m working with a function that computes a likelihood given a 1D array. The problem is that this array has variable length, and so the jit’ed function triggers recompilation when the length of this array changes, which happens frequently in my program, causing a significant hit in performance.
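
Here’s a minimal sketch of what I mean (log_likelihood below is just a placeholder body, not my real function): every call with a new length triggers a fresh trace and compile.

import jax.numpy as np
from jax import jit

@jit
def log_likelihood(x):
  print("tracing for shape", x.shape)  # runs only when JAX (re)traces
  return -0.5 * np.sum(x ** 2)

log_likelihood(np.ones(10))  # tracing for shape (10,)
log_likelihood(np.ones(10))  # cache hit, no retrace
log_likelihood(np.ones(11))  # tracing for shape (11,), recompiles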

I know that there are various levels of abstraction, and I see Unshaped arrays mentioned in the docs: “JAX can also trace at higher levels of abstraction, like Unshaped, but that’s not currently the default for any transformation”.

First, is it even possible with JAX to compile without declaring the shape a compile-time constant? Would using Unshaped arrays solve this? If so, how can I tell JAX to trace with this level of abstraction?

Thank you!

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 13 (6 by maintainers)

Top GitHub Comments

14 reactions
mattjj commented on Mar 27, 2020

Another great question!

There are two separate issues here worth disentangling:

  1. staging programs out of Python by trace-specializing, and
  2. compiling those staged-out programs using XLA.

The levels of abstraction in JAX pertain to the former. It’s true that JAX can trace-specialize at the Unshaped level, but semi-tangentially there’s an interesting tradeoff between amount of abstraction and the restrictions required for Python code to be traceable. For example, we couldn’t trace mean0 = lambda x: x.sum(0) / x.shape[0] at the Unshaped level, while if we specialize on shape that Python function is traceable. (We could trace lambda x: np.sin(np.cos(x)) at Unshaped just fine though.) There’s a relationship to reverse-mode autodiff too: we can transpose lambda x: x.sum(0) specialized on shape, but not unshaped, because we need to know what size to broadcast to. (We could transpose lambda x: x * 5. though.) Just from the point of view of this tradeoff, it’s been a nice sweet spot to specialize on shapes.
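
As a concrete illustration, make_jaxpr shows the staged-out program together with the shapes it was specialized on (the (3, 4) input shape below is arbitrary):

from jax import make_jaxpr
import jax.numpy as np

mean0 = lambda x: x.sum(0) / x.shape[0]
print(make_jaxpr(mean0)(np.ones((3, 4))))
# The jaxpr's types are concrete, e.g. f32[3,4]; x.shape[0] is an ordinary
# Python int at trace time, which is exactly what makes mean0 traceable.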

But you asked about compilation in particular. Let’s set issue 1 aside (say we have a Python function that we can specialize and transform at the Unshaped level, like lambda x: np.sin(np.cos(x))). Once we have a staged-out program (i.e. a jaxpr) that is only specialized at the Unshaped level, could we compile one XLA program for it, and thus not have to recompile for different shapes?

The answer is no: XLA HLO programs include concrete shapes in their type system. That is, just for issue 2 and unrelated to issue 1, we’d need to specialize the staged-out Unshaped program on argument shapes to be able to generate an XLA HLO program for it, and we’d have to re-invoke the compiler for each new shape specialization. XLA has its own reasons for specializing on shape: for instance, shape-specialized programs can be statically memory allocated and layout-optimized, and decisions like whether to fuse certain operations or rematerialize intermediates depend entirely on the shapes/sizes involved. That’s what gives XLA its incredible optimization power, especially on an accelerator like a TPU.
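
You can see those concrete shapes directly by dumping the HLO that JAX hands to XLA; here’s a small sketch using the xla_computation helper with the sin/cos example from above:

from jax import xla_computation
import jax.numpy as np

f = lambda x: np.sin(np.cos(x))
print(xla_computation(f)(np.ones(4)).as_hlo_text())
# The module's parameter type is f32[4]; tracing with np.ones(5) instead
# produces a different module with f32[5] baked into its types.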

You could say you’re willing to give up on some of those shape-specialized optimization opportunities if it meant you didn’t have to recompile as often. That makes sense! (In fact, on CPU and GPU we’re used to GEMM kernels that aren’t compile-time specialized on shapes, though a particular kernel may still be selected at runtime based on argument shapes.) People are exploring array-oriented compilers at those different points in the design space, and it’s exciting stuff. I expect JAX will take advantage of these kinds of capabilities as they emerge, in XLA and/or related technologies in the MLIR world. But for now the best compiler (XLA) requires shape-specialized programs. (You could actually lower an Unshaped-specialized jaxpr to a TensorFlow program pretty effectively, though I’m not sure if the TF graph executor (aka interpreter) has the right performance characteristics you’re looking for. Does it?)

But wait! We have a secret weapon here! JAX, transform!

One of the most exciting new JAX transformations is mask. It’s a prototype that needs a few weeks more work, and of course it’s totally undocumented as per our custom, but it can help us because mask can let us automatically implement shape-polymorphic semantics in shape-monomorphic programs. Here’s what I mean:

import jax.numpy as np
from jax import jit, mask

def bucket_jit(f):
  # Trace/compile once per padded bucket size instead of once per exact length.
  compiled_f = jit(mask(f, ['n'], ''))
  def wrapped(x):
    # Pad up to a multiple of 128 and record the true length n for mask.
    amount = 128 - x.shape[0] % 128
    padded_x = np.pad(x, (0, amount))
    return compiled_f([padded_x], dict(n=x.shape[0]))
  return wrapped

@bucket_jit
def foo(x):
  print("recompiling!")  # actually retracing, but effectively correct
  return np.tanh(np.sum(x))

foo(np.arange(4))  # recompiling!
foo(np.arange(5))
foo(np.arange(6))
foo(np.arange(300))  # recompiling!

from jax import grad
grad(foo)(np.arange(3.))  # recompiling!
grad(foo)(np.arange(4.))
grad(foo)(np.arange(129.))  # recompiling!

I’m not sure I’d recommend you use mask quite yet, but this is one example of what it can do. (There’s more! While we can do jit + mask to avoid recompilations, we can do vmap+mask to do ragged batching, e.g. for RNNs. We can also use the same machinery to do import-time shape checking of your code.) There are some examples in the tests.

We realized at one point that Unshaped (which doesn’t specialize on any shape information or even rank) was way too big a leap away from Shaped to be useful. Instead, mask specializes on partial shape information, in particular on parametrically polymorphic dimensions (which can be concrete too: do as much specialization as you like).

The mask transform doesn’t do anything you couldn’t do by hand. That is, you can pad and mask things by hand so as to implement shape-polymorphic semantics in a shape-monomorphic way. But it’s a big pain which scales badly with the complexity of the programs you want to handle. It’s just like autodiff: sure you can write derivative code by hand, but it’s much better to have an automatic, composable transformation do it for you.
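
For concreteness, here’s roughly what the by-hand version of the earlier foo looks like; this is just a sketch (foo_padded and foo_by_hand are illustrative names, and it assumes zero-padding plus a mask inside the sum is acceptable for your computation):

import jax.numpy as np
from jax import jit

@jit
def foo_padded(padded_x, n):
  # Zero out the padding so it can't affect the sum; n is the true length.
  keep = np.arange(padded_x.shape[0]) < n
  return np.tanh(np.sum(np.where(keep, padded_x, 0.)))

def foo_by_hand(x):
  bucket = 128 * ((x.shape[0] + 127) // 128)   # round up to a multiple of 128
  padded = np.pad(x, (0, bucket - x.shape[0]))
  return foo_padded(padded, x.shape[0])

Only the bucketed shape of padded_x is a compile-time constant (n is an ordinary traced argument), so foo_by_hand recompiles per bucket rather than per exact length, which is the same effect bucket_jit gets automatically.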

Anyway, here’s the short version of answers to your questions:

First, is it even possible with JAX to compile without declaring the shape a compile-time constant?

No, XLA HLO programs we compile must be shape-specialized, even though JAX can stage some programs out of Python based on less specialization.

Would using Unshaped arrays solve this?

No, because while it would avoid trace-specializing on shapes, we have no way to compile a single program with no shape specialization. (Forward-looking, it’s probably better to specialize on some shape information, even if just parametric polymorphism.)

If so, how can I tell JAX to trace with this level of abstraction?

There’s no user API for that, though if you want to fiddle with it for fun you can change make_jaxpr not to use xla.abstractify (which is effectively just lambda x: raise_to_shaped(core.get_aval(x))) and instead abstract to the Unshaped level (maybe like lambda x: Unshaped(core.get_aval(x).dtype)).

Hope that answers your question, though unfortunately I’m answering it in the negative. You could play with mask, but I suspect it might be tricky until we post at least minimal documentation, and even then we don’t know whether it’ll work well for your use case (are you on CPU or GPU?).

I’m going to close this issue on the hunch that this was enough of an answer, but please reopen (or open new ones) as needed. Questions are always welcome!

8 reactions
VolodyaCO commented on Aug 10, 2021

Any updates on the mask transformation?
