Whisper doesn't compute positional embeddings properly when given batches of prompt tokens
System Info
transformers v4.25.1 on an M1 Mac with Python 3.8
Who can help?
@sanchit-gandhi @patrickvonplaten @anton-l
Information
- The official example scripts
- My own modified scripts

Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
When we run Whisper generation for a batch of samples with different prompt lengths (prefix tokens given to the decoder), the decoder's positional embeddings are computed incorrectly. The current code assumes every sequence in the batch shares the same `past_key_values_length`, but this is not true in general.
Scenario:

```python
decoder_input_ids = [50361, 45431, 2584, 28682, 13, 50258, 50257, 50257]
# "<|startofprev|>Something completely irrelevant.<|startoftranscript|><|pad|><|pad|>"

model.generate(input_features, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask)
```

This will not give the correct output, because at the beginning of decoding the pad tokens are not taken into account, so the positional embeddings are off.
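For illustration, here is a rough sketch (a paraphrase, not the exact `transformers` source, and the function name is mine) of how the decoder currently derives positions: everything hangs off `past_key_values_length`, so every row in a batch is assumed to start at the same position, regardless of how many pad tokens it contains.

```python
import torch

def positions_from_past_length(decoder_input_ids, past_key_values_length=0):
    # Rough paraphrase of the current behaviour: positions are a plain arange
    # offset by past_key_values_length, identical for every sequence in the batch.
    bsz, seq_len = decoder_input_ids.shape
    positions = torch.arange(past_key_values_length, past_key_values_length + seq_len)
    return positions.unsqueeze(0).expand(bsz, -1)
```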
Expected behavior
Instead of tracking `past_key_values_length`, the decoder should use the attention mask to compute position ids. The current implementation is largely modelled on encoder-decoder architectures that would never do decoder prompting; it should take more inspiration from decoder-only models, which handle prompting correctly. This was done for the Flax implementation in #20479.
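A minimal sketch of the decoder-only-style computation being suggested (the helper name is mine, not a `transformers` API): pad tokens simply do not advance the position counter.

```python
import torch

def position_ids_from_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # The cumulative sum of the mask counts only real tokens, so a padded prompt
    # still lines up with the unpadded version of the same sequence.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Positions at pad slots are never attended to; clamp them to a valid index.
    return position_ids.masked_fill(attention_mask == 0, 0)
```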
Top GitHub Comments
cc @ArthurZucker
@hannan72's issue is separate from what I'm describing. But yes, padding should always go up to `max_length`. The issue I'm describing arises when pad tokens are added to shorter sequences in a batch; it won't raise any errors, it's just that Whisper's handling of multiple sequence lengths under the hood is flawed, and would be fixed by computing `position_ids` based off `attention_mask`.
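To make the failure mode concrete, here is a small self-contained comparison (the batch contents are invented for illustration; row 0 is the padded prompt from the scenario above). Arange-style positions keep counting through the pads, whereas mask-based positions do not.

```python
import torch

# Hypothetical batch: row 0 is a right-padded prompt, row 1 needs no padding.
decoder_input_ids = torch.tensor([
    [50361, 45431, 2584, 28682, 13, 50258, 50257, 50257],
    [50361, 45431, 2584, 28682, 13, 2584, 28682, 50258],
])
pad_token_id = 50257
attention_mask = (decoder_input_ids != pad_token_id).long()

# Current behaviour (paraphrased): identical positions for every row.
arange_positions = torch.arange(decoder_input_ids.shape[1]).expand(2, -1)

# Proposed behaviour: pad tokens do not advance the position counter.
mask_positions = attention_mask.cumsum(-1) - 1
mask_positions = mask_positions.masked_fill(attention_mask == 0, 0)

print(arange_positions)
# tensor([[0, 1, 2, 3, 4, 5, 6, 7],
#         [0, 1, 2, 3, 4, 5, 6, 7]])
print(mask_positions)
# tensor([[0, 1, 2, 3, 4, 5, 0, 0],
#         [0, 1, 2, 3, 4, 5, 6, 7]])
```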