Reading incorrect memory on TPU (ft. `flax.linen.transforms.scan`)

Problem you have encountered:

A simple RNNLM-style example (based on the distributed seq2seq example) appears to read wrong or unallocated memory, which produces nondeterministic results on Google Colab TPUs. Both flax.linen.transforms.scan and jax.value_and_grad seem necessary to trigger this behavior, as does the sketchy-looking chaining of multiple LSTM cells one after another (LMCell.__call__)…

What you expected to happen:

I defined a function that looks deterministic, triggers no errors or warnings, and behaves “properly” on CPU and GPU, then wrapped it in jax.value_and_grad (full code in notebook):

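import functools

import jax
import jax.numpy as jnp
from flax import linen as nn

# SEQUENCE_LENGTH, VOCAB_SIZE and HIDDEN_SIZE are defined in the linked notebook.

class LMCell(nn.Module):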
    @functools.partial(
        nn.transforms.scan,
        length=SEQUENCE_LENGTH-1,
        variable_broadcast='params',
        split_rngs={'params': False},
    )
    @nn.compact
    def __call__(self, lstm_states):
        inputs = self.param('stuff', nn.initializers.zeros, VOCAB_SIZE)
        new_lstm_states = []
        for lstm_state, lstm in zip(lstm_states, [nn.LSTMCell() for _ in range(3)]):
            new_lstm_state, inputs = lstm(lstm_state, inputs)
            new_lstm_states.append(new_lstm_state)
        logits = inputs
        return new_lstm_states, logits

class LM(nn.Module):
    @nn.compact
    def __call__(self):
        lstm_states = [(jnp.zeros(HIDDEN_SIZE, jnp.float32),) * 2] * 3
        _, logits = LMCell()(lstm_states)
        return -jnp.sum(logits)

params = LM().init(
    {'params': jax.random.PRNGKey(0)},
)['params']

@jax.value_and_grad
def f(params):
    return LM().apply(
        {'params': params},
        rngs={'lstm': jax.random.PRNGKey(0)},
    )

I would now expect this snippet to print the same value twice:

print(f(params)[0])
x = jnp.log(-jnp.ones((20000, 20000)))
print(f(params)[0])

…but on TPUs we get 0.0 and nan in the linked Colab.
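
The check described above (call the same function twice and compare the results) can be sketched as a small helper. This is only an illustrative sketch: `is_deterministic` and `stable_loss` are hypothetical names that do not appear in the original report, and NumPy is used so the snippet runs without an accelerator. `np.asarray` also accepts JAX arrays, so in principle the same helper could wrap a JAX function.

```python
import numpy as np

def is_deterministic(fn, *args, repeats=2):
    """Call fn repeatedly and report whether every result matches the first one."""
    first = np.asarray(fn(*args))
    for _ in range(repeats - 1):
        again = np.asarray(fn(*args))
        # equal_nan=True treats NaN as equal to NaN, so a result that is
        # consistently NaN still counts as deterministic.
        if not np.array_equal(first, again, equal_nan=True):
            return False
    return True

# Hypothetical stand-in for the loss; the real one is LM().apply(...) from the report.
def stable_loss(x):
    return -np.sum(x)

print(is_deterministic(stable_loss, np.zeros(4)))
```

Note that in the report the second call only diverged after a large array was allocated between the two calls, so a faithful check would need to perform that allocation in between as well.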

(As a bonus, removing the fairly useless intermediary module (LM) makes the code shorter and simpler, but then it sometimes doesn’t seem to terminate, and when it does, it yields different but still nondeterministic results. I didn’t try to isolate or analyze this one much, so feel free to ignore it.)

Logs, error messages, etc:

None.

Steps to reproduce:

https://colab.research.google.com/drive/18Tcz0gQp7Eride_cFbw09FNDFPVIoScv?usp=sharing (works the same on stable and master)

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 8 (1 by maintainers)

Top GitHub Comments

avital commented, Jan 11, 2021 (1 reaction)

I’ll close this since JAX is the right place to discuss this, and there’s a workaround for now. (https://github.com/google/jax/issues/5355). Thanks @sjmielke for reporting!

avital commented, Jan 6, 2021 (0 reactions)

@blakehechtman suggested that it may be a silent TPU OOM bug that isn’t captured correctly. @jheek's point about it not reproducing with smaller nets may suggest this as well. The 20000x20000 array, its negation, and the log are probably filling up device memory (1.5GB each).
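
The 1.5GB figure can be checked with quick arithmetic. The sketch below assumes float32 elements (4 bytes each), which is JAX's default dtype unless 64-bit mode is enabled; `array_nbytes` is a hypothetical helper, not part of JAX or Flax. Whether all three arrays are actually live on the device at the same time depends on the compiler, so the total is an upper-bound estimate rather than a measurement.

```python
import numpy as np

def array_nbytes(shape, dtype=np.float32):
    """Bytes needed for a dense array of the given shape and dtype."""
    count = 1
    for dim in shape:
        count *= dim
    return count * np.dtype(dtype).itemsize

per_array = array_nbytes((20000, 20000))  # 20000 * 20000 * 4 bytes

print(f"{per_array / 2**30:.2f} GiB per array")          # 1.49 GiB
print(f"{3 * per_array / 2**30:.2f} GiB for all three")  # 4.47 GiB
```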
