Trying to add support for GPT2 as decoder in EncoderDecoder model
🚀 Feature request
Hi, I am trying to add the option of using GPT2 as the decoder in the EncoderDecoder model, which currently only supports BERT-style models as the decoder.
Motivation
For a generation problem, it is usually better to use GPT2 as the decoder rather than BERT.
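As a concrete target, the end usage would be roughly the following (a minimal sketch only: passing a GPT2 checkpoint as the decoder is exactly what is not possible today, so treat this as the intended API rather than working code):

import torch
from transformers import EncoderDecoderModel

# Target usage (assumption): combine a pretrained BERT encoder with a
# pretrained GPT2 decoder once GPT2 is accepted as a decoder.
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "gpt2")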
Your contribution
I’ve made the following changes in the modeling_gpt2.py file:
- Added a crossattention layer to the Block class when the model is configured as a decoder:
class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            self.crossattention = Attention(nx, n_ctx, config, scale)
        ...

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None):
        output_attn = self.attn(
            self.ln_1(x),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
        outputs = []
        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                a, layer_past, attention_mask, head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            a = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:] + outputs
        return outputs  # x, present, (attentions)
- Added 3 Linear layers (query, key, value) instead of the fused Conv1D layer:
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        ...
        # self.c_attn = Conv1D(n_state * 3, nx)
        self.query = nn.Linear(n_state, nx)
        self.key = nn.Linear(n_state, nx)
        self.value = nn.Linear(n_state, nx)
        ...
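One open question with this change is how to keep loading the pretrained GPT2 weights: the original checkpoint stores the query/key/value projection as a single fused Conv1D (c_attn) whose weight has shape (nx, 3 * nx) and which computes x @ W + b, while nn.Linear stores (out_features, in_features) and computes x @ W.T + b. A rough, untested sketch of how the fused weights could be split onto the three new Linear layers:

import torch

def split_c_attn(c_attn_weight, c_attn_bias, n_embd):
    # Conv1D weight: (n_embd, 3 * n_embd); split the columns into q, k, v blocks.
    q_w, k_w, v_w = c_attn_weight.split(n_embd, dim=1)   # each (n_embd, n_embd)
    q_b, k_b, v_b = c_attn_bias.split(n_embd, dim=0)     # each (n_embd,)
    # Transpose each block so it matches nn.Linear's (out_features, in_features) layout.
    return (q_w.t(), q_b), (k_w.t(), k_b), (v_w.t(), v_b)

# Hypothetical usage when porting a pretrained checkpoint into the modified Attention:
# (q_w, q_b), (k_w, k_b), (v_w, v_b) = split_c_attn(old_c_attn.weight, old_c_attn.bias, n_embd)
# attention.query.weight.data.copy_(q_w); attention.query.bias.data.copy_(q_b)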
- Added encoder_hidden_states and encoder_attention_mask to the forward function of the Attention class, and used them for the key and the value when they are provided:
    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False,
                encoder_hidden_states=None, encoder_attention_mask=None):
        query = self.query(x)
        if encoder_hidden_states is not None:
            key = self.key(encoder_hidden_states)
            value = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            key = self.key(x)
            value = self.value(x)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        ...
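To check that swapping attention_mask for encoder_attention_mask in the cross-attention branch is shape-consistent, here is a toy broadcast check outside the model (the dimensions are made up; this only mirrors the broadcasting, not the actual Attention code):

import torch

batch, tgt_len, src_len, n_head, head_dim = 2, 5, 7, 12, 64

q = torch.randn(batch, n_head, tgt_len, head_dim)   # from the decoder hidden states x
k = torch.randn(batch, n_head, head_dim, src_len)   # from encoder_hidden_states (split_heads with k=True)
v = torch.randn(batch, n_head, src_len, head_dim)   # from encoder_hidden_states

w = torch.matmul(q, k)                        # (batch, n_head, tgt_len, src_len)
enc_mask = torch.zeros(batch, 1, 1, src_len)  # expanded encoder_attention_mask broadcasts over heads and queries
w = torch.softmax(w + enc_mask, dim=-1)
out = torch.matmul(w, v)                      # (batch, n_head, tgt_len, head_dim)
print(out.shape)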
- Added the encoder_hidden_states and encoder_attention_mask arguments to the GPT2Model forward function, and processed encoder_attention_mask the same way as attention_mask:
class GPT2Model(GPT2PreTrainedModel):
    ...
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        ...
        # Encoder attention mask (same processing as for the regular attention mask).
        if encoder_attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            encoder_attention_mask = encoder_attention_mask.view(batch_size, -1)
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1).unsqueeze(2)
            encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
        ...
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
        ...
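As a quick standalone check of the mask arithmetic (plain PyTorch, independent of the model): a (batch, src_len) 0/1 mask becomes a (batch, 1, 1, src_len) additive mask with a large negative value at padded positions, which is then added to the raw attention scores.

import torch

encoder_attention_mask = torch.tensor([[1, 1, 1, 0]])  # 1 = real token, 0 = padding
m = encoder_attention_mask.view(1, -1).unsqueeze(1).unsqueeze(2).float()
m = (1.0 - m) * -10000.0
print(m.shape)  # torch.Size([1, 1, 1, 4])
print(m)        # ~0 for real tokens, -10000.0 at the padded position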
- Added the encoder_hidden_states and encoder_attention_mask arguments to the GPT2LMHeadModel forward function, as well as lm_labels and masked_lm_labels for EncoderDecoder model compatibility (it is probably better to use GPT2DoubleHeadsModel):
class GPT2LMHeadModel(GPT2PreTrainedModel):
    ...
    def forward(
        self,
        input_ids=None,
        past=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
        lm_labels=None,
        masked_lm_labels=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
    ):
        ...
        if lm_labels is not None:
            if labels is not None:
                raise ValueError("You cannot specify both labels and lm_labels at the same time")
            labels = lm_labels

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        ...
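For reference, the loss that labels / lm_labels feed into is the usual causal-LM shift (this is the standard GPT2LMHeadModel behaviour, shown here on dummy tensors):

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 6, 50257
lm_logits = torch.randn(batch, seq_len, vocab_size)       # decoder output logits
labels = torch.randint(0, vocab_size, (batch, seq_len))   # target token ids

# Predict token t+1 from positions <= t.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))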
My biggest concern is with the second bullet, and I wanted to ask if this implementation seems right (for now it looks like I am able to train and test an EncoderDecoder with a BERT2GPT architecture).
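The kind of standalone smoke test I have in mind for the modified GPT2Model looks roughly like this (a sketch only: the import of GPT2Model refers to my locally modified file, and the is_decoder flag is part of the proposed change rather than the released library):

import torch
from transformers import GPT2Config
from modeling_gpt2 import GPT2Model  # hypothetical: the locally modified file described above

config = GPT2Config(is_decoder=True)  # is_decoder is assumed to be read by the modified Block
model = GPT2Model(config)

input_ids = torch.randint(0, config.vocab_size, (2, 5))
encoder_hidden_states = torch.randn(2, 7, config.n_embd)     # e.g. BERT encoder output (hidden size 768)
encoder_attention_mask = torch.ones(2, 7, dtype=torch.long)

outputs = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(outputs[0].shape)  # expected: (2, 5, config.n_embd)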
Of course, if needed, I can provide the full code for my changes, but all of them are listed above.
Most (if not all) of the code I’ve added is adapted from the Hugging Face modeling_bert.py file, so all of the credit goes to them.
Thanks
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
GPT2 is added and results on summarization look promising. Check out this model (Bert2GPT2 trained on CNN/Daily Mail), including a train and eval script: https://huggingface.co/patrickvonplaten/bert2gpt2-cnn_dailymail-fp16.
Will finish the PR tomorrow then it should be pretty easy to do BERT2GPT2.
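A minimal sketch of trying that checkpoint (the model id is taken from the comment above; the tokenizer checkpoints and generation arguments are assumptions, not taken from the linked model card):

from transformers import EncoderDecoderModel, BertTokenizer, GPT2Tokenizer

model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
enc_tok = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed encoder tokenizer
dec_tok = GPT2Tokenizer.from_pretrained("gpt2")               # assumed decoder tokenizer

article = "Some long news article to summarize ..."
input_ids = enc_tok(article, return_tensors="pt").input_ids
# The checkpoint name suggests fp16 weights; decoding choices here are illustrative only.
summary_ids = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
print(dec_tok.decode(summary_ids[0], skip_special_tokens=True))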