Trying to add support for GPT2 as decoder in EncoderDecoder model


🚀 Feature request

Hi, I am trying to add the option of using GPT2 as the decoder in the EncoderDecoder model, which currently only supports BERT-style models as the decoder.
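For context, here is a minimal sketch of what the intended usage would look like once GPT2 is supported as a decoder. The model names, tokenizer choices, and forward signature below are assumptions based on a recent transformers version, not part of the original request:

from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel

# BERT encoder + GPT2 decoder; the decoder's cross-attention weights are
# randomly initialized and have to be fine-tuned on a seq2seq task.
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")

encoder_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_ids = encoder_tokenizer("Some long article to summarize.", return_tensors="pt").input_ids
labels = decoder_tokenizer("A short summary.", return_tensors="pt").input_ids

# With labels provided, the model returns the decoder's LM loss.
outputs = model(input_ids=input_ids, decoder_input_ids=labels, labels=labels)
loss = outputs.loss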

Motivation

For generation problems, it is usually better to use GPT2 as the decoder rather than BERT.

Your contribution

I’ve made the following changes to the modeling_gpt2.py file:

  • Added a crossattention layer to the Block class, used only when the model is configured as a decoder:
class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = Attention(nx, n_ctx, config, scale)
...
    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, encoder_hidden_states=None,
                encoder_attention_mask=None):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
        outputs = []
        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                a, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask,
            )
            a = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:] + outputs

        return outputs  # x, present, (attentions)
  • Added 3 Linear layers (query, key, value) in place of the fused Conv1D layer:
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
...
        # self.c_attn = Conv1D(n_state * 3, nx)
        self.query = nn.Linear(n_state, nx)
        self.key = nn.Linear(n_state, nx)
        self.value = nn.Linear(n_state, nx)
...
  • Added encoder_attention_mask and encoder_hidden_states to the forward function of the Attention class, using them to compute the key and value projections when they are provided (a standalone sketch of this cross-attention pattern is given right after this list):
    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None):
        query = self.query(x)
        if encoder_hidden_states is not None:
            # Cross-attention: keys and values are computed from the encoder outputs,
            # so the encoder's padding mask replaces the decoder's attention mask.
            key = self.key(encoder_hidden_states)
            value = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            # Self-attention: queries, keys and values all come from the decoder states.
            key = self.key(x)
            value = self.value(x)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
...
  • Added the encoder_attention_mask and encoder_hidden_states arguments to the GPT2Model forward function, and processed encoder_attention_mask in the same way as attention_mask:
class GPT2Model(GPT2PreTrainedModel):
...
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
...
        # Encoder attention mask (same processing as for the regular attention mask).
        if encoder_attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            encoder_attention_mask = encoder_attention_mask.view(batch_size, -1)
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1).unsqueeze(2)
            encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            # Convert the 0/1 padding mask into an additive mask: 0.0 for real tokens,
            # -10000.0 for padding, so padded positions vanish after the softmax.
            encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
...
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
...
  • Added the encoder_attention_mask and encoder_hidden_states arguments to the GPT2LMHeadModel forward function, as well as lm_labels and masked_lm_labels for EncoderDecoder model compatibility (it is probably better to use GPT2DoubleHeadsModel):
class GPT2LMHeadModel(GPT2PreTrainedModel):
...
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
        lm_labels=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
...
        if lm_labels is not None:
            if labels is not None:
                raise ValueError("You cannot specify both labels and lm_labels at the same time")
            labels = lm_labels

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
...
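As referenced in the third bullet, here is a standalone sketch of the cross-attention pattern being added: queries come from the decoder states, keys and values from the encoder states, and the encoder's padding mask is applied additively before the softmax. The class and dimensions below are purely illustrative and are not taken from modeling_gpt2.py:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniCrossAttention(nn.Module):
    """Single-head cross-attention: queries from the decoder, keys/values from the encoder."""

    def __init__(self, hidden_size):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.scale = hidden_size ** -0.5

    def forward(self, decoder_states, encoder_states, encoder_attention_mask=None):
        q = self.query(decoder_states)               # (batch, tgt_len, hidden)
        k = self.key(encoder_states)                 # (batch, src_len, hidden)
        v = self.value(encoder_states)               # (batch, src_len, hidden)
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (batch, tgt_len, src_len)
        if encoder_attention_mask is not None:
            # Additive mask: 0.0 for real encoder tokens, -10000.0 for padding.
            scores = scores + encoder_attention_mask
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v)              # (batch, tgt_len, hidden)

# Example: 2 decoder positions attending over 5 encoder positions.
dec = torch.randn(1, 2, 16)
enc = torch.randn(1, 5, 16)
mask = torch.tensor([[1, 1, 1, 0, 0]]).float()
additive_mask = (1.0 - mask)[:, None, :] * -10000.0  # broadcast over target positions
out = MiniCrossAttention(16)(dec, enc, additive_mask)
print(out.shape)  # torch.Size([1, 2, 16])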

My biggest concern is with the second bullet, and I wanted to ask whether this implementation seems right (for now it looks like I am able to train and test an EncoderDecoder model with a BERT2GPT2 architecture); see the equivalence sketch below. If needed, I can provide the full code, but all of my changes are listed above. Most (if not all) of the code I’ve added is adapted from the Hugging Face modeling_bert.py file, so all of the credit goes to them.
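Regarding that second bullet: Hugging Face's Conv1D(nf, nx) computes x @ weight + bias with a weight of shape (nx, nf), i.e. it is a Linear layer with a transposed weight matrix, so replacing the fused c_attn Conv1D with three Linear projections is functionally equivalent as long as the pretrained weights are split and transposed when loading. A minimal sketch of that equivalence (Conv1D is re-implemented locally here so the example stays self-contained):

import torch
import torch.nn as nn

class Conv1D(nn.Module):
    """Local re-implementation of Hugging Face's Conv1D: x @ weight + bias,
    with weight of shape (nx, nf) -- i.e. a Linear layer with a transposed weight."""

    def __init__(self, nf, nx):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(nx, nf) * 0.02)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        out = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return out.view(*x.size()[:-1], -1)

nx = n_state = 768                       # in GPT2, n_state == n_embd == nx
c_attn = Conv1D(3 * n_state, nx)         # GPT2's fused query/key/value projection
x = torch.randn(2, 5, nx)

# Split the fused Conv1D into three Linear layers (as in the second bullet) and
# copy the weights over, transposing because nn.Linear stores (out_features, in_features).
w_q, w_k, w_v = c_attn.weight.split(n_state, dim=1)
b_q, b_k, b_v = c_attn.bias.split(n_state, dim=0)

query, key, value = nn.Linear(nx, n_state), nn.Linear(nx, n_state), nn.Linear(nx, n_state)
for lin, w, b in [(query, w_q, b_q), (key, w_k, b_k), (value, w_v, b_v)]:
    lin.weight.data.copy_(w.t())
    lin.bias.data.copy_(b)

# The three Linear projections reproduce the fused Conv1D output exactly.
q_ref, k_ref, v_ref = c_attn(x).split(n_state, dim=2)
assert torch.allclose(query(x), q_ref, atol=1e-5)
assert torch.allclose(key(x), k_ref, atol=1e-5)
assert torch.allclose(value(x), v_ref, atol=1e-5)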

Thanks

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 24 (11 by maintainers)

Top GitHub Comments

10 reactions
patrickvonplaten commented, Aug 18, 2020

GPT2 is added and results on summarization look promising. Check out this model (Bert2GPT2 trained on CNN/Daily Mail), including train and eval scripts: https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16
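For completeness, here is a rough sketch of loading that checkpoint and generating a summary; the tokenizer choices and generation settings are illustrative assumptions rather than the exact configuration from the model card:

from transformers import AutoTokenizer, EncoderDecoderModel

# Released Bert2GPT2 checkpoint fine-tuned on CNN/DailyMail (linked above).
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # assumed encoder tokenizer
decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")             # assumed decoder tokenizer

article = "..."  # a news article to summarize
inputs = encoder_tokenizer(article, return_tensors="pt", truncation=True, max_length=512)

summary_ids = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=64)
print(decoder_tokenizer.decode(summary_ids[0], skip_special_tokens=True))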

7 reactions
patrickvonplaten commented, Aug 12, 2020

Will finish the PR tomorrow, then it should be pretty easy to do BERT2GPT2.
