Autoregressive decoding currently fails if the input prompt is longer than one token
Problem you have encountered:
I want to run a model based on flax.linen.SelfAttention in auto-regressive (decode) mode and pass it an input prompt longer than one token. However, this does not seem possible at the moment, e.g.:
import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

attn_layer = SelfAttention(num_heads=1, decode=True, use_bias=False)
batch_size = 1
max_decoder_length = 4
hidden_size = 2
prompt_length = 2  # setting this to 1 would work

# Initialize parameters and an empty autoregressive cache.
init_variables = attn_layer.init(
    jax.random.PRNGKey(0),
    jnp.ones((batch_size, max_decoder_length, hidden_size)),
    deterministic=True)
params = init_variables["params"]
cache = init_variables["cache"]

# Try to feed the whole two-token prompt in a single decode call.
dummy_prompt = jnp.arange(batch_size * prompt_length * hidden_size).reshape(
    (batch_size, prompt_length, hidden_size))
output, mutated_variables = attn_layer.apply(
    {"params": params, "cache": cache}, dummy_prompt,
    mutable=["cache"], deterministic=True)
leads to an error. Also check this notebook.
What you expected to happen:
Instead, the code should work, and the first prompt_length cache entries should be stored.
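For illustration, here is a hedged sketch of what a primed cache looks like after a single-token decode step, which does work today. The cache entry names ("cached_key", "cached_value", "cache_index") are assumptions based on flax.linen.attention and may vary across Flax versions:

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

attn_layer = SelfAttention(num_heads=1, decode=True, use_bias=False)
variables = attn_layer.init(
    jax.random.PRNGKey(0), jnp.ones((1, 4, 2)), deterministic=True)

# A single-token step succeeds and fills exactly one cache position.
_, mutated = attn_layer.apply(
    variables, jnp.ones((1, 1, 2)), mutable=["cache"], deterministic=True)
print(jax.tree_util.tree_map(jnp.shape, mutated["cache"]))

# The expectation in this issue: a prompt of shape (1, prompt_length, 2) would
# fill the first prompt_length positions of cached_key/cached_value and advance
# cache_index to prompt_length, instead of raising the shape error below.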
Logs, error messages, etc:
~/python_bin/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
273 _context.module_stack.append(self)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
277 filter_fn = _context.capture_stack[-1]
~/python_bin/flax/linen/attention.py in __call__(self, inputs_q, inputs_kv, mask, deterministic)
265 expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
266 if expected_shape != query.shape:
--> 267 raise ValueError('Autoregressive cache shape error, '
268 'expected query shape %s instead got %s.' %
269 (expected_shape, query.shape))
ValueError: Autoregressive cache shape error, expected query shape (1, 1, 1, 2) instead got (1, 2, 1, 2).
Steps to reproduce:
See code/colab above
Top GitHub Comments
Hey! I commented on #1316 but hadn’t seen these comments when writing that. It sounds like you want to do a single-pass “teacher-forced” cache initialization. We’d need to add a third “operating mode” to the attention layer (basically, run normally but then stuff the keys and values for the first N tokens into the cache).
Apologies if I am misunderstanding, but for fast decoding we deliberately only feed 1 token at a time and assume the caller iterates over the input and maintains the cache. So you could do something like this in your code:
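(The original snippet is not shown here; the following is a minimal sketch of that approach, reusing the layer and shapes from the reproduction above: feed the decode-mode layer one prompt token per call and thread the mutated cache through.)

import jax
import jax.numpy as jnp
from flax.linen import SelfAttention

attn_layer = SelfAttention(num_heads=1, decode=True, use_bias=False)
batch_size, max_decoder_length, hidden_size, prompt_length = 1, 4, 2, 2

init_variables = attn_layer.init(
    jax.random.PRNGKey(0),
    jnp.ones((batch_size, max_decoder_length, hidden_size)),
    deterministic=True)
params = init_variables["params"]
cache = init_variables["cache"]

dummy_prompt = jnp.arange(
    batch_size * prompt_length * hidden_size,
    dtype=jnp.float32).reshape((batch_size, prompt_length, hidden_size))

# Prime the cache by feeding the prompt one token per call; each call writes
# one position of the cached keys/values and advances the cache index.
for i in range(prompt_length):
    token = dummy_prompt[:, i:i + 1, :]  # shape (batch, 1, hidden)
    output, mutated = attn_layer.apply(
        {"params": params, "cache": cache}, token,
        mutable=["cache"], deterministic=True)
    cache = mutated["cache"]

# `cache` now holds keys/values for the first prompt_length positions and can
# be used to keep decoding newly sampled tokens one step at a time.

This matches the intended fast-decoding usage (one query position per call), although the prompt is then processed sequentially rather than in a single teacher-forced pass.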
Are you saying that this approach doesn’t work for your use case? If so, could you please explain why not?