Inconsistencies and divergence depending on use of JIT
It seems that on some machines computational results differ significantly if `jit` is applied.
I came across this odd behavior in an implementation of a batched Monte Carlo integration. On some machines, when part of the code is `jit`-transformed, the results are significantly off and some `inf` values occur. The result seems to depend on ostensibly irrelevant code (adding zero times a non-`nan` expression) and on the specific sampling method. Because of this, I could not pin the issue down to a single expression; nor am I entirely sure I haven't missed something. Below is a description of the algorithm, as minimal as I could make it while still triggering the error, followed by a summary of the different behaviors.
The code
The code consists of the following steps (a condensed sketch follows the list; the full script comes after that):
- Sample complex points as solutions of a polynomial equation (`p + t*q` for fixed `p`, `q` such that `sum((p + t*q)**5) = 0`).
- For each sample point (each point consists of 4 complex numbers), compute a weight using `jax.grad`.
- Take the mean over `batch_size` of these weights as one batch step.
- Iterate over a given number of `batches` and add up all means obtained from the batch steps.
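Schematically, one batch step looks like this. This is only a condensed view, assuming the definitions (`sample_poly`, `divide_largest`, `weight`, and the `jnp` import) from the full script below; it is not additional code.

```python
# Condensed sketch of one batch step; names match the full script below.
def batch_step_sketch(key, batch_size):
    zs = sample_poly(key, batch_size)      # step 1: sample points with sum((p + t*q)**5) = 0
    zs = divide_largest(zs)                # normalize: divide by the largest entry and drop it
    ws = weight(zs, jnp.array(0.))         # step 2: per-sample weights via jax.grad
    return jnp.mean(ws)                    # step 3: mean over the batch

# Step 4: the batch-step means are summed over `batches` iterations,
# either with jax.lax.fori_loop or a plain Python loop.
```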
The following is a script taking two true/false arguments: whether to apply `jit`, and whether to use a `fori_loop` (the combination `true false` takes very long to compile).
It contains code to save the samples, so that the weights which should have been obtained can be checked afterwards, and it can be ruled out that the error occurs because something changes about the sampling (up to numerical error, the samples are the same regardless of `jit` use, as they should be, since the same keys are used).
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as onp
import functools
import sys


def pop(arr, index):
    # drop and return the value at index in arr
    rolled = jnp.roll(arr, -index, axis=0)
    return rolled[0], jnp.roll(rolled[1:], index, axis=0)


def insert_col(mat, col, index):
    # insert column col at index into matrix mat
    mat = jnp.roll(mat, -index, axis=1)
    mat = jnp.concatenate([col.reshape(-1, 1), mat], axis=1)
    return jnp.roll(mat, index, axis=1)


@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    return jnp.sum(z ** 5) + p * jnp.prod(z)


@functools.partial(jax.vmap, in_axes=(0, None))
def weight(z, p):
    grads = grad_eqn(z, p)
    dep = jnp.argmax(jnp.abs(grads))
    grad_max, grad_rest = pop(grads, dep)
    col = (-grad_rest / grad_max)[:, None]
    mat = jnp.concatenate((jnp.eye(3, dtype=jnp.complex_), col), axis=1)
    mat = mat @ mat.T.conj()
    det = jnp.linalg.det(mat).real
    return 1 / det


def sample_sphere(key, count, dim):
    points = jax.random.normal(key, (count, dim))
    return points / jnp.linalg.norm(points, axis=1, keepdims=True)


@jax.vmap
def solve_poly(p, q):
    # polynomial in t given by sum((p + t * q)**5)
    coeffs = jnp.array([q**5, 5 * p**1 * q**4, 10 * p**2 * q**3,
                        10 * p**3 * q**2, 5 * p**4 * q, p**5])
    coeffs = jnp.sum(coeffs, axis=1)
    roots = jnp.roots(coeffs, strip_zeros=False)
    return p.reshape(1, -1) + roots.reshape(-1, 1) * q.reshape(1, -1)


def sample_poly(key, count):
    # solutions come with multiplicity 5, so count / 5 p's and q's are needed
    base_count = jnp.ceil(count / 5).astype(int)
    # sample base_count complex p's and q's
    pqs = sample_sphere(key, 2, 2 * base_count * 5)
    ps, qs = (pqs[0] + 1j * pqs[1]).reshape(2, base_count, 5)
    sol = solve_poly(ps, qs)
    return sol.reshape(-1, 5)[:count, :]


@jax.vmap
def divide_largest(z):
    # divide by and drop the entry with the largest absolute value
    z0, z = pop(z, jnp.argmax(jnp.abs(z)))
    return z / z0


def monte_carlo(key, batches, batch_size, fori=True):
    keys = jax.random.split(key, batches)

    def batch_step(i, data):
        mean, samples = data
        key_sample = keys[i]
        zs = sample_poly(key_sample, batch_size)
        zs = divide_largest(zs)
        # save samples
        samples = jax.ops.index_update(samples, jax.ops.index[i, :, :], zs)
        weights = weight(zs, jnp.array(0.))
        return mean + jnp.mean(weights), samples

    mean = jnp.array(0.)
    samples = jnp.zeros((batches, batch_size, 4), dtype=jnp.complex_)
    if fori:
        mean, samples = jax.lax.fori_loop(0, batches, batch_step, (mean, samples))
    else:
        for i in range(batches):
            mean, samples = batch_step(i, (mean, samples))
    return mean, samples


if __name__ == '__main__':
    key = jax.random.PRNGKey(0)
    apply_jit = len(sys.argv) > 1 and sys.argv[1] == 'true'
    fori_loop = len(sys.argv) > 2 and sys.argv[2] == 'true'

    niter = 10
    batches = 51
    batch_size = 1000

    mc = functools.partial(monte_carlo, fori=fori_loop,
                           batches=batches, batch_size=batch_size)
    if apply_jit:
        mc = jax.jit(mc)
        save_name = 'samples-jit-%i.npy'
    else:
        save_name = 'samples-%i.npy'

    # skip some keys
    for i in range(25):
        _, key = jax.random.split(key)

    for i in range(niter):
        k0, key = jax.random.split(key)
        mean, samples = mc(k0)
        print(mean)
        # save samples to manually check the computations
        # onp.save(save_name % i, samples)
Behavior
As noted, the sample values do not differ between the `jit`/`fori_loop` combinations. Computing the weights and the means of the weights manually from the saved sample values always gives finite numerical values, consistent with those obtained without `jit` and without `fori_loop` (`$ python script.py false false`).
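For concreteness, the manual check can be done along these lines. This is only a sketch, assuming the save line `onp.save(save_name % i, samples)` in the script has been uncommented and that `weight` from the script is importable or defined in the same session.

```python
import numpy as onp
import jax.numpy as jnp

# Recompute the weights from the samples saved by the jitted run and compare the
# summed batch means against the value printed for the same iteration.
samples = onp.load('samples-jit-0.npy')          # shape (batches, batch_size, 4), complex
ws = jnp.stack([weight(jnp.asarray(zs), jnp.array(0.)) for zs in samples])
print(jnp.sum(jnp.mean(ws, axis=1)))             # finite, and consistent with the false/false run
```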
Depending on the computer, both cases in which `fori_loop` is used may give wrong values containing `inf`s.
This occurred on both local machines I tested with. Running the same code on Colab, however, gives the right (and identical) results in all combinations, which is why I suspect an underlying issue rather than one in the code.
The following are the first 10 results obtained with the above script in two different environments and various combinations of `jit` and `fori_loop`:

| XPS; jit=false, fori=false | XPS; jit=true, fori=true | XPS; jit=false, fori=true | Colab; jit=true, fori=true |
|---|---|---|---|
| 38.87047976604907 | 35.23827002862667 | 35.167724443321404 | 38.904195431290844 |
| 38.85501838205715 | inf | 35.21379197263621 | 38.875554832009456 |
| 38.87232142336747 | 35.07552733029048 | 35.16629159384102 | 38.9613642029768 |
| 38.82467883296542 | 35.268796318296 | 35.18550169177784 | 38.86870896981942 |
| 38.875347911324106 | 35.065090432638506 | 35.12925896136021 | 38.91082515791209 |
| 38.81607498879701 | 35.045350301233476 | 35.087313851691306 | 38.84038161357294 |
| 38.884758144142545 | 35.204243102525254 | 35.19112069680813 | 38.97964735892668 |
| 38.884639882640634 | inf | 35.23049623201075 | 38.907215776623836 |
| 38.96790493327401 | inf | 35.311082582397795 | 38.90340030598595 |
| 38.91302814793844 | 35.26023361519001 | 35.243122471869846 | 38.87890524435126 |
None of the complexities in the code seem to be removable without making the behavior disappear.
- If the sampling is replaced by plain uniform or normal random numbers, no more `inf`s appear and all combinations give the same results.
- Removing the wrapper function (which uses the `fori_loop`) around a batched step, and instead just returning the mean of a single batch, removes the issue and all results agree.
- The gradient is taken of a function `sum(z ** 5) + p * prod(z)`, where `p=0` is always passed. Given this, the gradient can be replaced manually with `5 * z**4` (in the real application the gradient would be less simple), which again removes the erroneous behavior.
- The most peculiar dependence on the specific implementation is the following: since `p=0`, the term `+ p * prod(z)` should not change the results. Removing it, however, also removes the issue (no `nan`s, and values around 38 rather than 35). Even in the modified form `+ 0 * jnp.nan_to_num(p * jnp.prod(z))` it reintroduces the error. (A sketch of these variants is given after this list.)
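To make the last two observations concrete, here is a sketch of the `grad_eqn` variants described above (the rest of the script is unchanged; the comments restate the behavior reported above, the variant names are mine):

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn(z, p):
    # original form: inf's appear under jit + fori_loop on the affected machines
    return jnp.sum(z ** 5) + p * jnp.prod(z)

@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn_no_prod(z, p):
    # term dropped entirely: no inf's, values around 38
    return jnp.sum(z ** 5)

@functools.partial(jax.grad, argnums=0, holomorphic=True)
def grad_eqn_nan_to_num(z, p):
    # term kept but multiplied by zero through nan_to_num: the error comes back
    return jnp.sum(z ** 5) + 0 * jnp.nan_to_num(p * jnp.prod(z))

def grad_eqn_manual(z, p):
    # gradient written out by hand (valid since p == 0 is always passed): no inf's
    return 5 * z ** 4
```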
Summary
The erroneous values seem to be connected with the use of `fori_loop` and with the gradient of `prod` multiplied such that it should vanish. The behavior appears contingent on the random sampling used, making it difficult to narrow down which specific expression is responsible. In particular, computing the weights after the samples have been computed gives the right results and does not reproduce the erroneous behavior. Any thoughts on how to narrow this down would be appreciated.
Testing environment
All tests were run with the CPU versions of jax and jaxlib installed via `pip`. The current jax version on Colab is 0.1.69.
The numerical results above were obtained on a Dell XPS 13 with an i5-7200U CPU, jax version 0.1.70, and Python 3.8.3.
I also saw the same behavior on a desktop machine with jax 0.1.72, a Xeon Silver 4114 CPU, and Python 3.6.9.
I'm not sure what other environment variables may be relevant.
Comments (6, all by maintainers)
PR #3611 seems to fix the problem for me!
I think what's happening here is that our 2-pass implementation of `argmax` is breaking for this benchmark because XLA chooses to recompute the input for each pass, and it ends up at least a little bit different, breaking the exact equality the algorithm expects.
Here's a smaller reproduction of what I believe to be going wrong:
Output:
The fix is probably to switch our `argmax` implementation to use a 1-pass algorithm via a variadic reduction.
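For readers unfamiliar with the distinction, the following is only an illustrative sketch of what a 2-pass `argmax` looks like at the `jnp` level (it is not JAX's actual implementation): if the input is recomputed between the two passes and comes out even slightly different, the exact-equality test in the second pass can fail for every index.

```python
import jax.numpy as jnp

def two_pass_argmax(x):
    # pass 1: find the maximum value
    m = jnp.max(x)
    # pass 2: smallest index whose entry equals that maximum exactly;
    # if nothing matches (because x was recomputed slightly differently),
    # the result falls back to the out-of-range sentinel x.shape[0]
    idx = jnp.arange(x.shape[0])
    return jnp.min(jnp.where(x == m, idx, x.shape[0]))
```

An index like that, out of range or simply pointing at the wrong entry, would make the `pop` in `weight` divide by something other than the largest gradient, which is consistent with the occasional `inf`s in the batched results. A 1-pass variadic reduction instead carries the (value, index) pair through a single reduction, so there is no input to recompute between passes.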