
BertLMHeadModel (w/ relative position embedding) does not work correctly when use_cache = True


System Info

  • transformers version: 4.20.1
  • Platform: Linux-5.4.0-92-generic-x86_64-with-glibc2.17
  • Python version: 3.8.13
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@LysandreJik

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

I found that BertLMHeadModel (w/ relative position embedding) sometimes generates unexpected sequences when use_cache = True.

Here is a minimal code sample that indirectly demonstrates this problem:

import torch
from transformers import BertConfig, BertLMHeadModel

config = BertConfig(
    is_decoder=True, vocab_size=10, hidden_size=64, num_hidden_layers=1, num_attention_heads=4,
    intermediate_size=64, position_embedding_type='relative_key')
model = BertLMHeadModel(config).eval()

with torch.no_grad():
    # 1) Greedy generation without the cache; keep the attention weights.
    model.config.use_cache = False
    generation = model.generate(bos_token_id=1, max_length=5, output_attentions=True, return_dict_in_generate=True)
    print(generation.attentions[-1][0][:, :, -1:, :])

    # 2) A plain forward pass over the generated prefix; the last-position attention should match 1).
    prediction = model(input_ids=generation.sequences[:, :-1], output_attentions=True)
    print(prediction.attentions[0][:, :, -1:, :])

    # 3) The same generation with the cache enabled; this should also match, but it does not.
    model.config.use_cache = True
    generation = model.generate(bos_token_id=1, max_length=5, output_attentions=True, return_dict_in_generate=True)
    print(generation.attentions[-1][0])

Outputs:

tensor([[[[0.2455, 0.2530, 0.2558, 0.2457]],

         [[0.2495, 0.2492, 0.2497, 0.2516]],

         [[0.2481, 0.2516, 0.2514, 0.2489]],

         [[0.2496, 0.2538, 0.2533, 0.2433]]]])
tensor([[[[0.2455, 0.2530, 0.2558, 0.2457]],

         [[0.2495, 0.2492, 0.2497, 0.2516]],

         [[0.2481, 0.2516, 0.2514, 0.2489]],

         [[0.2496, 0.2538, 0.2533, 0.2433]]]])
tensor([[[[0.2452, 0.2532, 0.2548, 0.2468]],

         [[0.2498, 0.2492, 0.2494, 0.2516]],

         [[0.2485, 0.2516, 0.2516, 0.2483]],

         [[0.2492, 0.2538, 0.2528, 0.2442]]]])

Expected behavior

The three printed attention tensors should all have the same values, but the third one (generated with use_cache = True) differs. (The generated sequences happen to be identical in this case, but once the model is trained, different sequences can be generated depending on use_cache.)
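
A compact way to state that expectation programmatically (a hedged sketch that reuses the model built in the snippet above; not part of the original report): the last decoding step's self-attention should be identical with and without the cache.

import torch

# Hedged sketch: compare the last generation step's self-attention with and
# without the cache, reusing `model` from the reproduction above.
with torch.no_grad():
    model.config.use_cache = False
    no_cache = model.generate(bos_token_id=1, max_length=5,
                              output_attentions=True, return_dict_in_generate=True)

    model.config.use_cache = True
    cached = model.generate(bos_token_id=1, max_length=5,
                            output_attentions=True, return_dict_in_generate=True)

# Without the cache the last step has a full query axis, so slice out the final
# query position to match the cached single-query shape.
print(torch.allclose(no_cache.attentions[-1][0][:, :, -1:, :],
                     cached.attentions[-1][0], atol=1e-6))
# Expected: True. Currently prints False because of the bug described below.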

The cause of this problem is that BertSelfAttention’s relative position embedding does not handle the use_cache = True case properly. It seems that this problem can be fixed by modifying BertSelfAttention’s forward function as follows:

# ...

use_cache = past_key_value is not None
if self.is_decoder:
    # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
    # Further calls to cross_attention layer can then reuse all cross-attention
    # key/value_states (first "if" case)
    # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
    # all previous decoder key/value_states. Further calls to uni-directional self-attention
    # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
    # if encoder bi-directional self-attention `past_key_value` is always `None`
    past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
    query_length, key_length = query_layer.shape[2], key_layer.shape[2]
    if use_cache:
        position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    else:
        position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
    position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
    distance = position_ids_l - position_ids_r

    # ...

(The current code always makes the distance variable become tensor([[0]]) when use_cache = True, because both position id ranges are derived from the length of hidden_states, which is 1 during cached decoding.)
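
To make that concrete, here is a small standalone sketch (illustrative only, not the library source), assuming a single cached decoding step with four keys:

import torch

# Illustration: during cached decoding the new token is the only query,
# while the cache already holds several keys.
query_length, key_length = 1, 4

# Current behaviour: both position-id ranges are built from the length of
# hidden_states (i.e. the query length), so the only distance produced is 0.
position_ids_l = torch.arange(query_length, dtype=torch.long).view(-1, 1)
position_ids_r = torch.arange(query_length, dtype=torch.long).view(1, -1)
print(position_ids_l - position_ids_r)  # tensor([[0]])

# Intended behaviour (matching the proposed fix): the single query sits at
# position key_length - 1 and attends over keys at positions 0 .. key_length - 1.
position_ids_l = torch.tensor([[key_length - 1]], dtype=torch.long)
position_ids_r = torch.arange(key_length, dtype=torch.long).view(1, -1)
print(position_ids_l - position_ids_r)  # tensor([[3, 2, 1, 0]])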

Other models using the same code also need modifications…

Also, BertLMHeadModel’s generate function does not overwrite the use_cache option. It seems that BertLMHeadModel’s prepare_inputs_for_generation function should add a use_cache item to the output dictionary, as some other models already do; a rough sketch is shown below.
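
For illustration, the suggested change might look roughly like this (a hedged sketch of the override, not the actual patch; the rest of the body follows the existing method’s logic, with the extra use_cache entry added):

# Hedged sketch of the suggested change (not the merged patch): have
# prepare_inputs_for_generation also forward the use_cache flag, so that
# generate() actually toggles caching instead of silently ignoring it.
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **model_kwargs):
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)
    if past is not None:
        # only the last token is needed once past key/values are available
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past,
        "use_cache": use_cache,  # the item the issue proposes adding
    }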

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 6 (3 by maintainers)

Top GitHub Comments

2 reactions
ArthurZucker commented, Nov 14, 2022

Hey! So after investigating in detail, it seems that we do indeed have a problem, but the good news is that it is not a major one.

First, we have to use a model that was trained with relative_key, so I used "zhiheng-huang/bert-base-uncased-embedding-relative-key".

  • The attention scores are indeed different (and so are the last logits), but the result after the softmax is always the same. This seems to come from the learned embedding, which doesn’t have a huge impact once the model has already been trained, but it could impact training.

Minimal reproducing script:

import torch
from transformers import BertTokenizer, BertLMHeadModel, set_seed

tokenizer = BertTokenizer.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key")
model = BertLMHeadModel.from_pretrained("zhiheng-huang/bert-base-uncased-embedding-relative-key", is_decoder=True)
inputs = tokenizer("No I'm not missing the ", return_tensors="pt")
input_ids = inputs.input_ids[:, :-1]
attention_mask = inputs.attention_mask[:, :-1]

with torch.no_grad():
    # Single full forward pass without the cache.
    model.config.use_cache = False
    set_seed(0)
    output = model(input_ids, attention_mask=attention_mask, use_cache=False)
    print(output.logits[:, -1, :])

    # Same prediction computed incrementally through the cache.
    model.config.use_cache = True
    output_1 = model(input_ids[:, :-1], use_cache=True, attention_mask=attention_mask[:, :-1])
    pkv = output_1.past_key_values
    output_2 = model(input_ids[:, -1:], past_key_values=pkv, use_cache=True)
    print(output_2.logits[:, -1, :])

tensor([[-5.4971, -6.4888, -8.3359,  ..., -7.3612, -5.5480, -0.9784]])
tensor([[ -7.2693,  -7.7799, -10.0905,  ...,  -7.5183,  -7.4255,  -4.6804]])

With your fix we indeed have

tensor([[-5.4971, -6.4888, -8.3359,  ..., -7.3612, -5.5480, -0.9784]])
tensor([[-5.4971, -6.4888, -8.3359,  ..., -7.3612, -5.5480, -0.9784]])

This should have been tested when merging the model, but it seems like it was not. I will open a PR to address this.

0 reactions
ArthurZucker commented, Oct 16, 2022

Hey, I am currently investigating whether we should indeed change the attention or not. As a lot of models depend on it, I wanna make sure this would be backward compatible! But if you want, feel free to open a PR. 😄


