
Inconsistencies and divergence depending on use of JIT


It seems that on some machines computational results differ significantly if jit is applied.

I came across this odd behavior in an implementation of a batched Monte Carlo integration. On some machines, when part of the code is jit-transformed, the results are significantly off and some inf values occur. The result seems to depend on ostensibly irrelevant code (adding zero times a nan_to_num-wrapped expression) and on the specific sampling method. Because of this, I could not pin the issue down to a single expression; nor am I entirely sure I haven't missed something. Below is a description of the algorithm, as minimal as I could make it while still reproducing the error, followed by a summary of the different behaviors.

The code

The code consists of the following steps:

  1. Sample complex points as solutions of a polynomial equation (points of the form p + t*q for fixed p and q, where t solves sum((p + t*q)^5) = 0).
  2. For each sample point (each point consists of 4 complex numbers), compute a weight using jax.grad.
  3. Take the mean of these weights over the batch_size samples as one batch step.
  4. Iterate over a given number of batches and add up all the means obtained from the batch steps.

The following is a script taking two true/false arguments: whether to apply jit and whether to use fori_loop (the combination true false takes very long to compile). It contains parts to save the samples, so that the weights that should have been obtained can be checked afterwards and it can be ruled out that the error is caused by a change in the sampling itself (up to numerical error the samples are the same regardless of jit use, as they should be, since the same keys are used).

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as onp
import functools
import sys


def pop(arr, index):
    # remove the entry at index and return it together with the remaining entries (order preserved)
    rolled = jnp.roll(arr, -index, axis=0)
    return rolled[0], jnp.roll(rolled[1:], index, axis=0)


def insert_col(mat, col, index):
    # insert column col at index into matrix mat
    mat = jnp.roll(mat, -index, axis=1)
    mat = jnp.concatenate([col.reshape(-1, 1), mat], axis=1)
    return jnp.roll(mat, index, axis=1)


@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    return jnp.sum(z ** 5) + p * jnp.prod(z)


@functools.partial(jax.vmap, in_axes=(0, None))
def weight(z, p):
    grads = grad_eqn(z, p)
    dep = jnp.argmax(jnp.abs(grads))
    
    grad_max, grad_rest = pop(grads, dep)
    col = (-grad_rest / grad_max)[:, None]
    
    mat = jnp.concatenate((jnp.eye(3, dtype=jnp.complex_), col), axis=1)
    mat = mat @ mat.T.conj()
    det = jnp.linalg.det(mat).real

    return 1 / det 


def sample_sphere(key, count, dim):
    points = jax.random.normal(key, (count, dim))
    return points / jnp.linalg.norm(points, axis=1, keepdims=True)


@jax.vmap
def solve_poly(p, q):
    # polynomial in t given by sum((p + t * q)**5), coefficients from highest degree first
    coeffs = jnp.array([q**5, 5 * p**1 * q**4, 10 * p**2 * q**3,
                        10 * p**3 * q**2, 5 * p**4 * q, p**5])
    coeffs = jnp.sum(coeffs, axis=1)

    roots = jnp.roots(coeffs, strip_zeros=False)
    return p.reshape(1, -1) + roots.reshape(-1, 1) * q.reshape(1, -1)


def sample_poly(key, count):
    # each (p, q) pair yields 5 solutions, so need ceil(count / 5) pairs
    base_count = jnp.ceil(count / 5).astype(int)

    # sample base_count complex p's and q's (each with 5 components)
    pqs = sample_sphere(key, 2, 2 * base_count * 5)
    ps, qs = (pqs[0] + 1j * pqs[1]).reshape(2, base_count, 5)

    sol = solve_poly(ps, qs)
    return sol.reshape(-1, 5)[:count, :]


@jax.vmap
def divide_largest(z):
    # divide by and drop largest absolute entry
    z0, z = pop(z, jnp.argmax(jnp.abs(z)))
    return z / z0


def monte_carlo(key, batches, batch_size, fori=True):
    keys = jax.random.split(key, batches)

    def batch_step(i, data):
        mean, samples = data
        key_sample = keys[i]

        zs = sample_poly(key_sample, batch_size)
        zs = divide_largest(zs)
        # save samples
        samples = jax.ops.index_update(samples, jax.ops.index[i, :, :], zs)

        weights = weight(zs, jnp.array(0.))
        return mean + jnp.mean(weights), samples

    mean = jnp.array(0.)
    samples = jnp.zeros((batches, batch_size, 4), dtype=jnp.complex_)

    if fori:
        mean, samples = jax.lax.fori_loop(0, batches, batch_step, (mean, samples))
    else:
        for i in range(batches):
            mean, samples = batch_step(i, (mean, samples))

    return mean, samples


if __name__ == '__main__':
    key = jax.random.PRNGKey(0)

    apply_jit = len(sys.argv) > 1 and sys.argv[1] == 'true'
    fori_loop = len(sys.argv) > 2 and sys.argv[2] == 'true'

    niter = 10
    batches = 51
    batch_size = 1000

    mc = functools.partial(monte_carlo, fori=fori_loop,
            batches=batches, batch_size=batch_size)

    if apply_jit:
        mc = jax.jit(mc)
        save_name = 'samples-jit-%i.npy' 
    else:
        save_name = 'samples-%i.npy'

    # skip some keys
    for i in range(25):
        _, key = jax.random.split(key)

    for i in range(niter):
        k0, key = jax.random.split(key)
        mean, samples = mc(k0)
        print(mean)
        # save samples to manually check computations
        # onp.save(save_name % i, samples)

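The script is invoked with the two flags in the order (jit, fori_loop), for example:

python script.py false false   # no jit, plain Python loop
python script.py true true     # jit with fori_loop
python script.py false true    # fori_loop without jit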
Behavior

As noted, the sample values do not differ across the jit and fori_loop combinations. Computing the weights and their means manually from the saved sample values always gives finite values, consistent with those obtained without jit and without fori_loop (python script.py false false). Depending on the machine, however, both combinations that use fori_loop may give wrong values containing inf's. This occurred on both local machines I tested with. Running the same code on Colab gives the correct (and identical) results in all combinations, which is why I suspect an underlying issue rather than one in the code.
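For reference, the manual check can be done roughly as follows (a minimal sketch, assuming the commented-out onp.save lines above were enabled for a jit run and that the script's weight function is in scope):

import numpy as onp
import jax.numpy as jnp

samples = onp.load('samples-jit-0.npy')   # shape (batches, batch_size, 4), complex
total = jnp.array(0.)
for batch in samples:
    # recompute the weights for one saved batch, outside jit / fori_loop
    w = weight(jnp.asarray(batch), jnp.array(0.))
    total = total + jnp.mean(w)
print(total)   # finite and consistent with the (jit=false, fori=false) column below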

The following are the first 10 results obtained with the above script in two different environments and various combinations of jit and fori_loop:

XPS (jit=false, fori=false)   XPS (jit=true, fori=true)   XPS (jit=false, fori=true)   Colab (jit=true, fori=true)
38.87047976604907 35.23827002862667 35.167724443321404 38.904195431290844
38.85501838205715 inf 35.21379197263621 38.875554832009456
38.87232142336747 35.07552733029048 35.16629159384102 38.9613642029768
38.82467883296542 35.268796318296 35.18550169177784 38.86870896981942
38.875347911324106 35.065090432638506 35.12925896136021 38.91082515791209
38.81607498879701 35.045350301233476 35.087313851691306 38.84038161357294
38.884758144142545 35.204243102525254 35.19112069680813 38.97964735892668
38.884639882640634 inf 35.23049623201075 38.907215776623836
38.96790493327401 inf 35.311082582397795 38.90340030598595
38.91302814793844 35.26023361519001 35.243122471869846 38.87890524435126

None of the complexities in the code seem removable without making the behavior disappear:

  • If the sampling is replaced by plain uniform or normal random numbers, no infs appear and all combinations give the same results.
  • Removing the wrapper function (which uses fori_loop) around the batch step, and instead just returning the mean of a single batch, removes the issue and all results agree.
  • The gradient is taken of the function sum(z ** 5) + p * prod(z), where p=0 is always passed. The gradient can therefore be replaced manually with 5 * z**4 (in the real application the gradient would be less simple), which again removes the erroneous behavior.
  • The most peculiar dependency on the specific implementation is the following: since p=0, the term + p * prod(z) should not change the results. Removing it, however, also removes the issue (no NaNs, and values around 38 rather than 35). Even writing it in the modified form + 0 * jnp.nan_to_num(p * jnp.prod(z)) reintroduces the error. (These variants are written out in the sketch after this list.)
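For concreteness, these are the variants of grad_eqn referred to in the list above (a sketch; the names other than grad_eqn are just labels for illustration):

import functools
import jax
import jax.numpy as jnp

# original form: triggers the inconsistency under jit + fori_loop
@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    return jnp.sum(z ** 5) + p * jnp.prod(z)

# dropping the p * prod(z) term (p is always 0 anyway): issue disappears
@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn_no_prod(z, p):
    return jnp.sum(z ** 5)

# hand-written gradient of sum(z ** 5): issue disappears
def grad_eqn_manual(z, p):
    return 5 * z ** 4

# still a mathematical no-op for p = 0, but reintroduces the error
@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn_nan_to_num(z, p):
    return jnp.sum(z ** 5) + 0 * jnp.nan_to_num(p * jnp.prod(z))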

Summary

The erroneous values seem to be connected with the use of fori_loop and with the gradient of prod being multiplied such that it should vanish. The behavior appears contingent on the random sampling used, making it difficult to narrow down a specific expression that is responsible. In particular, computing the weights after the samples have been computed gives the right results and does not reproduce the erroneous behavior. Any thoughts on how to narrow this down further would be appreciated.

Testing environment

All tests were run with the CPU version of jax and jaxlib installed via pip. The current jax version on Colab is 0.1.69.

The numerical results above were obtained on a Dell XPS 13 with i5-7200U CPU, jax version 0.1.70, and python 3.8.3. I also saw the same behavior on a desktop machine with jax 0.1.72, Xeon Silver 4114 CPU, and python 3.6.9. I’m not sure what other environment variables may be relevant.


Top GitHub Comments

hawkinsp commented on Jun 30, 2020 (1 reaction)

PR #3611 seems to fix the problem for me!

hawkinsp commented on Jun 30, 2020 (1 reaction)

I think what's happening here is that our 2-pass implementation of argmax is breaking for this benchmark because XLA chooses to recompute the input for each pass, and it ends up at least a little bit different, breaking the exact equality the algorithm expects.
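Schematically, a 2-pass argmax first reduces to the maximum value and then searches for an index whose value compares exactly equal to it, so any tiny difference in the recomputed input between the two passes means no index matches. The following is only an illustration of that failure mode, not the actual XLA lowering:

import numpy as onp

def two_pass_argmax(recompute_input):
    x1 = recompute_input()               # pass 1: find the maximum value
    m = x1.max()
    x2 = recompute_input()               # pass 2: the input may be recomputed here
    matches = onp.nonzero(x2 == m)[0]    # relies on exact equality with the pass-1 max
    # if x2 differs from x1 by even one ulp, nothing matches and a sentinel
    # "not found" index falls out (compare the int64-max values in the output below)
    return matches[0] if matches.size else onp.iinfo(onp.int64).max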

Here’s a smaller reproduction of what I believe to be going wrong:

import functools

import jax
import jax.numpy as jnp
import numpy as onp

from jax.config import config
config.update("jax_enable_x64", True)

@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
  return jnp.sum(z**5) + p * jnp.prod(z)

@jax.vmap
def foo(pp):
  w = pp
  grads = grad_eqn(w, jnp.array(0.))
  dep = jnp.argmax(jnp.abs(grads))
  return dep

pp = onp.array([
    [
        0.013914733367254616 + 9.3094023595025388e-01j, 0.7261907418672447 -
        3.5631603853758342e-01j, 0.4192612993988415 + 1.5784670831992959e-01j,
        -0.3979203076971916 + 9.0370153633300654e-01j
    ],
    [
        0.7412515559421278 + 5.8209475619094264e-01j, -0.7882039695924775 +
        8.2271693304329360e-02j, 0.2473007941817975 + 2.9011353191281253e-01j,
        0.3953279913354844 - 8.7640244493755098e-02j
    ],
    [
        0.42107995148850474 - 5.2786845011285044e-01j,
        -0.9752016614993378 + 4.6493984002048627e-02j, 0.14899616328175808 -
        2.5660252776982036e-02j, 0.7597534752633921 - 2.9817585059867030e-01j
    ],
    [
        0.07919062080298368 + 2.0443288812872526e-01j, -0.9144440581831745 +
        5.0375688239721973e-02j, 0.7189319378618075 - 4.7781041939994917e-01j,
        0.28616184673974954 + 5.8406319450724586e-01j
    ],
    [
        -0.9124599088445945 + 2.0523072700406764e-02j,
        -0.49751861383647644 + 6.6934941229283518e-01j, -0.09127256955394417 -
        2.5171143504173576e-01j, 0.6288139554943047 - 6.0843626330134115e-01j
    ],
    [
        0.10383981031500984 - 5.8961771394978757e-01j,
        0.16722015050067904 - 5.4033783049084971e-01j,
        -0.3570686716415365 + 7.9098843585535017e-01j,
        -0.36717711934106606 - 8.7325230201503201e-01j
    ],
    [
        -0.4246625385849623 - 1.0866918975223600e-01j, -0.08009235275156984 -
        5.0556054151587737e-01j, 0.4353126751296819 - 1.7484451023920469e-01j,
        -0.29468982612935085 + 9.4805376113112527e-01j
    ],
    [
        -0.33924468705124716 + 9.3201718347059126e-01j, -0.31334017694099114 -
        1.0582535646176845e-01j, 0.6724881365502651 + 2.2066280675715644e-01j,
        -0.5552091363575214 - 1.4490197928805444e-02j
    ],
    [
        0.6526283624759489 + 4.9140814443027613e-02j, 0.723922878773778 +
        3.9038679286342901e-02j, 0.762488834126774 - 6.2337884354704665e-01j,
        -0.39085070216092266 + 7.9604864742611015e-01j
    ],
    [
        0.48937363384974575 - 5.1103393766115682e-02j, 0.7539474954530498 -
        3.3604089724531272e-01j, 0.7479090648182917 - 6.4041309134555557e-01j,
        -0.34829014191239205 + 3.2357526022662669e-01j
    ],
    [
        -0.09058883804930377 + 3.5194499540548291e-01j,
        -0.013080375527656812 - 5.3701680483099290e-01j, -0.7690608912513537 -
        7.1176000810729834e-02j, 0.7408780750401596 - 5.8939023080126030e-01j
    ],
    [
        -0.1448148911478361 + 1.1680988499664302e-01j,
        0.6357579332535874 + 4.8007163633817268e-01j, -0.20356990725787427 -
        9.2667400504246511e-01j, -0.486415685914065 - 6.7109379995585339e-01j
    ],
    [
        0.6208929968646449 - 2.6040937438741862e-01j,
        0.18721272656033006 + 4.2043707645400447e-01j, -0.29536647248428616 +
        7.3555331996715753e-02j, -0.9936616656553029 + 2.7635687670210948e-02j
    ],
    [
        0.12004262077972877 + 1.2780756181079903e-01j, -0.3058488919814152 +
        9.5148878693497396e-01j, 0.3781176412486092 - 1.4516939708606397e-01j,
        0.20414911926428903 + 2.8066844435687388e-01j
    ],
    [
        0.44393385508190725 + 4.1301429499018250e-01j,
        0.775660315846601 - 6.1108689231453883e-01j, -0.22979601723668555 +
        5.6425408946759868e-02j, 0.6348633420533681 - 2.1792338815000339e-01j
    ],
    [
        0.5903309747549792 + 5.7825608084895608e-04j, -0.9688715927965774 -
        8.0116867317087063e-03j, 0.6118501784052073 + 4.1552864146077695e-01j,
        0.011088080369843475 - 1.2357486664766727e-01j
    ]
],
               dtype=onp.complex128)
nojit = foo(pp)
withjit = jax.jit(foo)(pp)
print("  no jit: ", nojit)
print("with jit: ", withjit)
onp.testing.assert_allclose(nojit, withjit)

Output:

  no jit:  [3 0 1 1 0 3 3 0 2 2 3 2 3 1 1 1]
with jit:  [                  3                   0                   1
                   1                   0                   3
                   3                   0                   2
                   2                   3                   2
                   3                   1 9223372036854775807
 9223372036854775807]
Traceback (most recent call last):
  File "u.py", line 109, in <module>
    onp.testing.assert_allclose(nojit, withjit)
  File "/Users/phawkins/.pyenv/versions/py3.7.4/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1528, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/Users/phawkins/.pyenv/versions/py3.7.4/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 840, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 2 / 16 (12.5%)
Max absolute difference: 9223372036854775806
Max relative difference: 1.
 x: array([3, 0, 1, 1, 0, 3, 3, 0, 2, 2, 3, 2, 3, 1, 1, 1], dtype=int64)
 y: array([                  3,                   0,                   1,
                         1,                   0,                   3,
                         3,                   0,                   2,...

The fix is probably to switch our argmax implementation to use a 1-pass algorithm via a variadic reduction.
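For illustration, the 1-pass idea is to carry (value, index) pairs through a single reduction so the input is only read once. A rough sketch (not the actual variadic-reduction implementation):

import jax
import jax.numpy as jnp

def one_pass_argmax(x):
    # carry the running (best value, best index) pair through a single pass
    def combine(carry, elem):
        best_val, best_idx = carry
        val, idx = elem
        take_new = val > best_val            # strict '>' keeps the first maximum
        return (jnp.where(take_new, val, best_val),
                jnp.where(take_new, idx, best_idx)), None

    init = (x[0], jnp.array(0, dtype=jnp.int32))
    elems = (x[1:], jnp.arange(1, x.shape[0], dtype=jnp.int32))
    (_, best_idx), _ = jax.lax.scan(combine, init, elems)
    return best_idx

x = jnp.array([0.3, 2.7, 1.1, 2.7])
print(one_pass_argmax(x), jnp.argmax(x))   # both give 1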
