
inner `jit` functions are re-traced (and re-compiled)


It seems that jit decorators are ignored for nested functions. Consider the following example (simplified from an actual program, where inner_fn is slightly more complex):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, inline=True)
def inner_fn(state):
    print("entering inner_fn now")
    return 2*state

@jax.jit
def outer_fn(x):
    print("entering outer_fn now")
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
    x = x + old_x
    return x

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn(state).block_until_ready()

Expected behavior: As both functions are decorated with jit, inner_fn should be traced and compiled once and the cached result re-used by the plain Python for loop inside outer_fn. Consequently, the output should show only one emission of “entering inner_fn now”.

Actual behavior: The jit decorator for inner_fn seems to be ignored: “entering inner_fn now” is printed 10 times, and setting the env var JAX_LOG_COMPILES=1 prints only a single line, for the compilation of outer_fn. In short, we observe the same behavior as if the jit decorator on inner_fn were absent.
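
(For reference, the same compile log can also be switched on from inside the program rather than via the environment variable; a minimal sketch, assuming the jax_log_compiles config option that backs the env var:)

import jax

# Equivalent to launching the process with JAX_LOG_COMPILES=1: JAX then logs a
# line each time it actually compiles a jitted function.
jax.config.update("jax_log_compiles", True)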

Why a bug?: The current behavior negates any caching done for jitted functions, resulting in repeated tracing of a function with identical tracers, inflated compilation times, and behavior that can easily surprise the user.

Other considerations:

  1. Using fori_loop: Using a fori_loop results in inner_fn being traced only once, so it mitigates the issue for this particular example (see the sketch after this list). However, we were actually having trouble with it, since it prevents unrolling the loop, and our inner_fn was small enough that we were seeing performance hits from the per-iteration cuda kernel launches for inner_fn in this case.
  2. No inline=True for the jit decorator of inner_fn: In this case the compilation produces 10 separate sub-jaxprs for inner_fn that are not inlined but are all invoked via xla_call (visible when inspecting the program with jax.make_jaxpr). The jit decorator is therefore not completely without effect, but this is arguably the worst case: separate xla_call overhead as with the fori_loop AND multiple compilations of inner_fn.
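
As a concrete reference for point 1, here is a minimal sketch of the fori_loop variant (reusing inner_fn from the reproduction above); it traces inner_fn only once, but, as noted, at the cost of the loop no longer being unrolled:

import jax
import jax.numpy as jnp

@jax.jit
def outer_fn_fori(x):
    old_x = x
    # lax.fori_loop traces its body a single time, so "entering inner_fn now"
    # is printed only once -- but the loop is no longer unrolled.
    x = jax.lax.fori_loop(0, 10, lambda i, s: inner_fn(s), x)
    return x + old_x

outer_fn_fori(jnp.arange(5, dtype=jnp.uint32)).block_until_ready()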


Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented on Jul 1, 2021

A concrete next step on our end is to add thorough logging so that these times are super easy to inspect.
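
(Until such logging exists, a rough stopgap is to time the first call, which includes tracing and compilation, against a second call with identical shapes and dtypes; a sketch, reusing outer_fn and state from the reproduction above:)

import time

t0 = time.perf_counter()
outer_fn(state).block_until_ready()   # traces, compiles, then runs
t1 = time.perf_counter()
outer_fn(state).block_until_ready()   # served from the jit cache, runs only
t2 = time.perf_counter()

print(f"first call (incl. compilation): {t1 - t0:.3f}s, cached call: {t2 - t1:.3f}s")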

0 reactions
lumip commented on Jul 5, 2021

Alright, I tried to boil it down to its essence in the following. Basically, I have a piece of code which repeatedly applies a vector-valued function (col_fn) to all columns and (some form of) diagonals of a 4x4 matrix; the per-matrix step is realised by inner_fn, and the repeated application is a loop in outer_fn. My basic approach for inner_fn was to first vmap col_fn over the columns, then reorganise the matrix so that the diagonals I’m interested in end up in columns, and apply the vmapped col_fn again. This remapping is what seems to be causing problems for me. Code below (col_fn is just a stand-in right now, with a loop to adjust its complexity in terms of the number of simple operations it applies):


import jax
import jax.numpy as jnp
import numpy as np

def col_fn(col): # applied to a column of a 4x4 matrix
    for _ in range(3): # number of operations inside can influence fusing behaviour
        col = col << 7
    return col

def inner_fn(m):
    col_vmap = jax.vmap(col_fn, in_axes=1, out_axes=1)
    # col_vmap = col_fn # vmap or no vmap seems to make NO difference

    # apply col_fn to each column
    m = col_vmap(m)

    # then apply col_fn to each diagonal:
    # moving diagonals into columns
    diag_map = np.array([
         0,  1,  2,  3,
         5,  6,  7,  4,
        10, 11,  8,  9,
        15, 12, 13, 14
    ])

    # option 1: jax.lax.gather : results in splits; splits always occur after a gather or concatenate, not necessarily aligned with iteration counts
    gdn = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,1), start_index_map=(0,1))
    indices = jnp.array(((0,0), (0,1), (0,2), (0,3), (1,1), (1,2), (1,3), (1,0), (2,2), (2,3), (2,0), (2,1), (3,3), (3,0), (3,1), (3,2)))
    diag_m = jax.lax.gather(m, indices, gdn, slice_sizes=(1,1), unique_indices=True).reshape((4,4))

    # option 2: jnp.take : same as option 1 (but with some additional instructions that appear to be argument checks)
    # note: as written, options 1 and 2 are both active and this overwrites the diag_m computed above; enable one option at a time
    diag_map = diag_map.reshape((4,4))
    diag_m = jnp.take(m, diag_map)

    # option 3: flatten - lookup - reshape : seems to either fuse completely or create a fuse per iteration in outer_fn, depending on iterations in col_fn
    # diag_m = m.ravel()[diag_map].reshape(4,4)

    # option 4: no index mapping : will always result in a single large fuse (no matter how complex col_fn)
    # diag_m = m
    # diag_m = col_vmap(diag_m)

    return diag_m


@jax.jit
def outer_fn(m):
    old_m = m
    for _ in range(7): # number of iterations influences number of kernels
        m = inner_fn(m)
    m = m + old_m # this gets absorbed into the last split
    return m

m = np.arange(16).reshape((4,4))
outer_fn(m).block_until_ready()

As you can see, I tried different ways of reorganising the matrix and they result in different fusing outcomes, which I investigated by looking at the outputs produced when setting the --xla_dump_hlo_as_text XLA flag.
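
(For reproducibility: the dumps come from setting XLA_FLAGS before JAX initialises its backends; a sketch of one way to do that — the --xla_dump_to destination is my own example, not from the original setup:)

import os

# must be set before jax initialises XLA so the flags are picked up
os.environ["XLA_FLAGS"] = "--xla_dump_hlo_as_text --xla_dump_to=/tmp/hlo_dump"

import jax  # imported only after XLA_FLAGS is in place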

  • Options 1 and 2 are basically identical and insert seemingly arbitrary splits after the emitted gather operations in the HLO (i.e., a fusion always ends with a gather op). This roughly merges two iterations of inner_fn, so for 7 iterations I end up with 4 fusions: the first covers the first 1.5 iterations (stopping at the reorganising in the 2nd iteration), the second covers the second half of the 2nd iteration, the entire 3rd, and the first half of the 4th, and so on. This seems invariant to the number of operations in col_fn or the number of iterations of inner_fn.
  • Option 3: This either results in a single fusion combining all iterations OR in one fusion per iteration of inner_fn, depending entirely on the complexity/number of iterations in col_fn: in my tests, I get a single fusion for fewer than 13 iterations in col_fn and one fusion per inner_fn invocation for more. The splits again occur at the gather ops, so the first fusion covers the first half of the first invocation of inner_fn, and each following fusion covers a second half plus the first half of the next iteration.
  • Option 4: No reorganising of the matrix takes place, therefore there are no gather ops emitted and the result is a single fusion in all cases, no matter the complexity of col_fn or the number of iterations of inner_fn.

Further: In reality, col_fn is a bit more involved than shown above and requires interdependent updates to the components of the received vector, something akin to this (with some randomly chosen elementary operations):


def col_fn_complex(col): # applied to a column of a 4x4 matrix
    a,b,c,d = col
    for _ in range(3): # number of operations inside can influence fusing behavior
        a = a << 7
        b = b + a
        c = c * b
    col = jnp.array([a,b,c,d])
    return col

I haven’t been able to do this without at some point using a concatenating operation to produce the output vector (or using an index_update), which results in even more splits of the fusions in the output HLO dump.
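For illustration, the index_update route mentioned above would today be written with .at[...].set(...); a sketch of col_fn_complex in that style (whether the resulting scatter-style updates avoid the extra splits is precisely the open question):

import jax.numpy as jnp

def col_fn_complex_at(col):  # applied to a column of a 4x4 matrix
    a, b, c, d = col
    for _ in range(3):
        a = a << 7
        b = b + a
        c = c * b
    # write the updated components back in place instead of rebuilding the
    # vector with jnp.array / concatenate (d stays unchanged, so no update for
    # index 3); these updates lower to scatter ops
    col = col.at[0].set(a)
    col = col.at[1].set(b)
    col = col.at[2].set(c)
    return col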

Other things I tried:

  • I have tried a different implementation in which col_fn receives the full matrix and a list of indices instead of a vector slice of the matrix, to avoid having to reorganise the matrix. However, this merely moves the gather from the reorganising in inner_fn to the index lookup that is then required at the start of col_fn.
  • I have also tried extracting the diagonals by multiplying the matrix with a 0/1-masking matrix followed by a summation along the rows (roughly as in the sketch below). Doing that, I get the same result as with Option 3 above, with splits occurring at the reduce ops that replace the gather ops (plus additional add kernels for each iteration that feed the reduce ops).
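
For completeness, here is a sketch of that masking variant as I understand it (the mask construction is my reconstruction, assuming col_fn and the diag_map layout from the snippet above):

import jax
import jax.numpy as jnp
import numpy as np

# 0/1 masks: masks[j][i, k] == 1 exactly where diag_m[i, j] = m[i, (i + j) % 4],
# i.e. one mask per wrapped diagonal of the 4x4 matrix (matches diag_map above).
i_idx, k_idx = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
masks = np.stack([(k_idx == (i_idx + j) % 4).astype(np.uint32) for j in range(4)])

def inner_fn_masked(m):  # assumes col_fn from the snippet above
    m = jax.vmap(col_fn, in_axes=1, out_axes=1)(m)
    # elementwise multiply with each mask, then sum along the rows; the reduce
    # ops emitted for these sums take the place of the gather ops in the HLO
    diag_m = jnp.stack([(m * masks[j]).sum(axis=1) for j in range(4)], axis=1)
    return diag_m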