Reading incorrect memory on TPU (ft. `flax.linen.transforms.scan`)
Problem you have encountered:
A simple RNNLM-esque example (based on the distributed seq2seq example) seems to read wrong/unallocated memory, resulting in nondeterminism on Google Colab TPUs. Both flax.linen.transforms.scan and jax.value_and_grad seem necessary to trigger this behavior, as does the sketchy-looking chaining of multiple LSTM cells after one another (LMCell.__call__)…
What you expected to happen:
I defined a function that looks deterministic, triggers no errors or warnings, and acts “properly” on CPU and GPU, and wrapped it in jax.value_and_grad (full code in the notebook):
class LMCell(nn.Module):
    @functools.partial(
        nn.transforms.scan,
        length=SEQUENCE_LENGTH - 1,
        variable_broadcast='params',
        split_rngs={'params': False},
    )
    @nn.compact
    def __call__(self, lstm_states):
        inputs = self.param('stuff', nn.initializers.zeros, VOCAB_SIZE)
        new_lstm_states = []
        for lstm_state, lstm in zip(lstm_states, [nn.LSTMCell() for _ in range(3)]):
            new_lstm_state, inputs = lstm(lstm_state, inputs)
            new_lstm_states.append(new_lstm_state)
        logits = inputs
        return new_lstm_states, logits

class LM(nn.Module):
    @nn.compact
    def __call__(self):
        lstm_states = [(jnp.zeros(HIDDEN_SIZE, jnp.float32),) * 2] * 3
        _, logits = LMCell()(lstm_states)
        return -jnp.sum(logits)
params = LM().init(
    {'params': jax.random.PRNGKey(0)},
)['params']

@jax.value_and_grad
def f(params):
    return LM().apply(
        {'params': params},
        rngs={'lstm': jax.random.PRNGKey(0)},
    )
I would now expect this snippet to print the same value twice:
print(f(params)[0])
x = jnp.log(-jnp.ones((20000, 20000)))
print(f(params)[0])
…but on TPUs we get 0.0 and nan in the linked Colab.
(Also, as a bonus: removing the kind-of-useless intermediary module (LM) makes the code shorter and simpler… but then it sometimes doesn’t seem to terminate, and when it does, it yields different but still non-deterministic results. I didn’t try to isolate or analyze this variant much, so feel free to ignore it; a sketch of what it looks like follows below.)
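For reference, here is a rough sketch of that shorter variant; the original code isn’t shown in the issue, so the exact init/apply calls below are my reconstruction:

# Reconstruction (not the author's exact code): drop the LM wrapper and
# drive LMCell directly through init/apply.
lstm_states = [(jnp.zeros(HIDDEN_SIZE, jnp.float32),) * 2] * 3
params = LMCell().init({'params': jax.random.PRNGKey(0)}, lstm_states)['params']

@jax.value_and_grad
def g(params):
    _, logits = LMCell().apply({'params': params}, lstm_states)
    return -jnp.sum(logits)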
Logs, error messages, etc:
None.
Steps to reproduce:
https://colab.research.google.com/drive/18Tcz0gQp7Eride_cFbw09FNDFPVIoScv?usp=sharing (works the same on stable and master)
Top GitHub Comments
I’ll close this since JAX is the right place to discuss this, and there’s a workaround for now. (https://github.com/google/jax/issues/5355). Thanks @sjmielke for reporting!
@blakehechtman suggested that it may be a silent TPU OOM bug that isn’t being surfaced correctly. @jheek’s point about it not reproducing with smaller nets suggests this as well. Probably the 20000x20000 array, the negation, and the log are filling up device memory (~1.5 GB each).
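A quick back-of-the-envelope check of that theory (my arithmetic, assuming float32 buffers and eager, unfused execution):

# Rough size of one 20000x20000 float32 buffer; jnp.ones, the negation and
# jnp.log can each transiently hold a buffer of this size.
elems = 20000 * 20000            # 4.0e8 elements
buf_bytes = elems * 4            # float32 -> 1.6e9 bytes
print(buf_bytes / 2**30)         # ~1.49 GiB per intermediate, several GiB peak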