
AutoRegressive Decoding currently fails if input prompt length > 1

See original GitHub issue


Problem you have encountered:

I want to run a model based on flax.linen.SelfAttention in auto-regressive mode and pass an input prompt longer than one token. This does not currently seem possible, e.g.:

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

# A single-head self-attention layer in autoregressive decode mode.
attn_layer = SelfAttention(1, decode=True, use_bias=False)

batch_size = 1
max_decoder_length = 4
hidden_size = 2
prompt_length = 2   # setting this to 1 would work

# init sees the full decoder length so the cache buffers get the right shape.
init_variables = attn_layer.init(
    jax.random.PRNGKey(0),
    jnp.ones((batch_size, max_decoder_length, hidden_size)),
    deterministic=True)

params = init_variables["params"]
cache = init_variables["cache"]

dummy_prompt = jnp.arange(batch_size * prompt_length * hidden_size).reshape(
    (batch_size, prompt_length, hidden_size))

# Feeding a two-token prompt in decode mode fails here.
output, mutable_vars = attn_layer.apply(
    {"params": params, "cache": cache}, dummy_prompt,
    mutable=["cache"], deterministic=True)

leads to an error. Also check this notebook.

What you expected to happen:

Instead, the code should work, and the keys and values for the first prompt_length tokens should be stored in the cache.
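For reference (my addition, not part of the report): the cache created by init holds per-position key/value buffers plus a write index. Continuing from the snippet above, its layout can be inspected directly; the names cached_key/cached_value/cache_index reflect current flax.linen.attention internals and may differ between Flax versions.

# Inspect the shapes of the cache variables created by init.
print(jax.tree_util.tree_map(jnp.shape, cache))
# {'cache_index': (),
#  'cached_key': (1, 4, 1, 2),    # (batch, max_decoder_length, heads, head_dim)
#  'cached_value': (1, 4, 1, 2)}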

Logs, error messages, etc:

~/python_bin/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    273     _context.module_stack.append(self)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:
    277         filter_fn = _context.capture_stack[-1]

~/python_bin/flax/linen/attention.py in __call__(self, inputs_q, inputs_kv, mask, deterministic)
    265         expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
    266         if expected_shape != query.shape:
--> 267           raise ValueError('Autoregressive cache shape error, '
    268                            'expected query shape %s instead got %s.' %
    269                            (expected_shape, query.shape))

ValueError: Autoregressive cache shape error, expected query shape (1, 1, 1, 2) instead got (1, 2, 1, 2).
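For contrast (my addition, not from the issue), the same apply call succeeds when the query carries a single token, which matches the (1, 1, 1, 2) query shape the decode-mode check expects:

# Continuing from the snippet above: feed only the first prompt token.
one_token = dummy_prompt[:, :1, :]  # shape (batch, 1, hidden_size)
output, mutable_vars = attn_layer.apply(
    {"params": params, "cache": cache}, one_token,
    mutable=["cache"], deterministic=True)
print(output.shape)  # (1, 1, 2)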

Steps to reproduce:


See the code and Colab above.

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

1 reaction
levskaya commented, May 14, 2021

Hey! I commented on #1316 but hadn’t seen these comments when writing that. So I think it sounds like you want to do a single-pass “teacher-forced” cache-initialization. We’d need to add a third “operating mode” for the attention layer (basically, run normally but then stuff the keys and values for the first N tokens into a cache).
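To make that concrete, here is a rough sketch of what such a cache-priming pass could look like from the outside. This is entirely speculative: the parameter layout params["key"]["kernel"] and the cache entries cached_key/cached_value/cache_index follow current Flax internals and are not a public API.

def prime_cache_with_prompt(params, cache, prompt):
  """Hypothetical 'teacher-forced' cache initialization: project the prompt
  through the layer's key/value heads and write the results into the first
  prompt-length cache slots, bypassing the one-token decode path."""
  # DenseGeneral kernels have shape (features, num_heads, head_dim);
  # use_bias=False in the example, so no bias terms are needed.
  keys = jnp.einsum("bld,dhk->blhk", prompt, params["key"]["kernel"])
  values = jnp.einsum("bld,dhk->blhk", prompt, params["value"]["kernel"])
  prompt_len = prompt.shape[1]
  cache = dict(cache)  # shallow, mutable copy of the cache collection
  cache["cached_key"] = cache["cached_key"].at[:, :prompt_len].set(keys)
  cache["cached_value"] = cache["cached_value"].at[:, :prompt_len].set(values)
  cache["cache_index"] = jnp.array(prompt_len, dtype=cache["cache_index"].dtype)
  return cache

Note this only fills the key/value buffers so that later one-token decode steps can attend to the prompt; the outputs for the prompt positions themselves would still have to come from a normal (non-decode) forward pass.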

1 reaction
marcvanzee commented, May 14, 2021

> I want to run a model based on flax.linen.SelfAttention in auto-regressive mode and pass an input prompt longer than one token.

Apologies if I am misunderstanding, but for fast decoding we deliberately only feed 1 token at a time and assume the caller iterates over the input and maintains the cache. So you could do something like this in your code:

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention
from jax import lax

attn_layer = SelfAttention(1, decode=True, use_bias=False)

batch_size = 1
max_decoder_length = 4
hidden_size = 2
prompt_length = 2

init_variables = attn_layer.init(
    jax.random.PRNGKey(0),
    jnp.ones((batch_size, max_decoder_length, hidden_size)),
    deterministic=True)

params = init_variables["params"]
cache = init_variables["cache"]

prompts = jnp.arange(batch_size * prompt_length * hidden_size).reshape(
    (batch_size, prompt_length, hidden_size))
outputs = jnp.zeros(prompts.shape)  # filled one token at a time below

for i in range(prompt_length):
  # Slice out token i and run a single decode step against the cache.
  prompt = lax.slice_in_dim(prompts, i, i + 1, axis=1)
  output, mutable_vars = attn_layer.apply(
      {"params": params, "cache": cache}, prompt,
      mutable=["cache"], deterministic=True)
  outputs = lax.dynamic_update_slice(outputs, output, (0, i, 0))
  cache = mutable_vars["cache"]  # update cache for the next iteration

print(outputs)

Are you saying that this approach doesn’t work for your use case? If so, could you please explain why not?
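As a side note (my addition, not from the original thread), each iteration of that loop is a pure function of (params, cache, token), so the per-token step can be compiled once with jax.jit:

@jax.jit
def decode_step(params, cache, token):
  # One token in, one token out; the caller threads the updated cache through.
  output, mutable_vars = attn_layer.apply(
      {"params": params, "cache": cache}, token,
      mutable=["cache"], deterministic=True)
  return output, mutable_vars["cache"]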

