
Can GPT2LMHeadModel do batch inference with variable sentence lengths?

See original GitHub issue

Given that the GPT2 tokenizer does not have an internal pad_token_id, how do I pad sentences and do batch inference using GPT2LMHeadModel? Specifically, my code is:

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

prompt_text = [
    'in this paper we',
    'we are trying to',
    'The purpose of this workshop is to check whether we can', ]

tokens = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(x, add_prefix_space=True)) for x in prompt_text]

# pad the shorter prompts with eos_token_id so the batch is rectangular
inputs = pad_sequence([torch.LongTensor(x) for x in tokens], batch_first=True, padding_value=tokenizer.eos_token_id)

# no attention mask is passed, so the eos padding is attended to
outputs, past = model(input_ids=inputs, attention_mask=None)

This returns irrelevant predictions, since GPT-2 treats the eos_token padding as a document boundary and effectively starts a new sentence for the shorter prompts in the batch.
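
For example (a rough sketch reusing tokens, inputs, model, and tokenizer from the snippet above), the greedy next token taken at the last position of a padded row is conditioned on the trailing eos tokens, so it generally differs from the one taken at the real end of the prompt:

unpadded = torch.LongTensor([tokens[0]])  # 'in this paper we' with no padding
padded = inputs[:1]                       # the same prompt, right-padded with eos up to the batch length

with torch.no_grad():
    next_real = model(input_ids=unpadded)[0][:, -1, :].argmax(-1)    # prediction after 'we'
    next_padded = model(input_ids=padded)[0][:, -1, :].argmax(-1)    # prediction after the trailing eos tokens

print(tokenizer.decode([next_real.item()]), '|', tokenizer.decode([next_padded.item()]))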

Can anyone please share sample code that uses GPT2LMHeadModel to do batch inference with variable sentence lengths?

Thanks!

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 41 (18 by maintainers)

Top GitHub Comments

13 reactions
XinyuHua commented, Feb 26, 2020

I tried a rough version: basically, add an attention mask over the padding positions and keep updating this mask as the generation grows. One thing worth noting is that in the first step, instead of extracting the output at the -1-th position for each sample, we need to keep track of the real prompt ending position; otherwise the output at a padding position may be extracted and produce random results.

Code snippet:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompt_text = [
    'in this paper we',
    'we are trying to',
    'The purpose of this workshop is to check whether we can', ]
batch_size = len(prompt_text)
max_length = 30
eos_token_id = tokenizer.eos_token_id

model = model.cuda()

token_ids = [tokenizer.encode(s, add_special_tokens=False) for s in prompt_text]
prompt_lengths = [len(s) for s in token_ids]
max_prompt_len = max(prompt_lengths)

# use 0 as padding id, shouldn't matter
padded_tokens = [tok_ids + [0] * (max_prompt_len - len(tok_ids)) for tok_ids in token_ids]
input_ids = torch.LongTensor(padded_tokens).cuda()
attn_mask = torch.zeros(input_ids.shape).long().cuda()
for ix, tok_ids in enumerate(token_ids):
    attn_mask[ix][:len(tok_ids)] = 1

unfinished_sents = input_ids.new(batch_size).fill_(1)
past = None
cur_len = input_ids.shape[1]

def post_processing(input_ids, attn_mask):
    """Remove padding tokens in the middle of the sequence."""
    input_ids_proc = []
    for ix, seq in enumerate(input_ids.tolist()):
        # keep only positions whose attention mask is 1, so decode() receives plain ints
        input_ids_proc.append([tok_id for tok_id, mask in zip(seq, attn_mask[ix].tolist()) if mask != 0])
    return input_ids_proc


input_lengths_index = torch.tensor([x - 1 for x in prompt_lengths]).cuda()
input_lengths_index = input_lengths_index.view(-1, 1).repeat(1, 50257).unsqueeze(1)

while cur_len < max_length:
    model_inputs = model.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attn_mask)
    outputs = model(**model_inputs)
    if cur_len == max_prompt_len:
        # at first step we can't directly extract the -1-th position's
        # prediction for next word, since for some samples the -1-th
        # token is PAD. Instead we keep track of the real prompt ending.
        next_token_logits = outputs[0].gather(1, input_lengths_index).squeeze(1)
    else:
        next_token_logits = outputs[0][:, -1, :]
    past = outputs[1]
    next_token = torch.argmax(next_token_logits, dim=-1)
    tokens_to_add = next_token * unfinished_sents + 0 * (1 - unfinished_sents)
    input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
    attn_mask = torch.cat([attn_mask, torch.ones((batch_size, 1)).long().cuda()], dim=1)

    unfinished_sents.mul_(tokens_to_add.ne(eos_token_id).long())
    cur_len += 1

    if unfinished_sents.max() == 0:
        break

input_ids = post_processing(input_ids, attn_mask)
for item in input_ids:
    print(tokenizer.decode(item))

Also a minor change to src/transformers/modeling_gpt2.py:

line 422: attention_mask = attention_mask.view(-1, input_shape[-1])

change to attention_mask = attention_mask.view(input_shape[0], -1)

(not sure if this change will break other things)
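
To see why this helps (a toy sketch, not from the thread): once past is used, input_ids only carries the newest token per step, so input_shape[-1] is 1 while the attention mask still covers the whole sequence so far; the original view folds the sequence length into the batch dimension, whereas the proposed view keeps the batch size intact:

import torch

batch_size, seq_len_so_far = 2, 5
attention_mask = torch.ones(batch_size, seq_len_so_far)
input_shape = (batch_size, 1)  # with past, the model is fed only the newest token per step

print(attention_mask.view(-1, input_shape[-1]).shape)  # torch.Size([10, 1]) -- batch dim scrambled
print(attention_mask.view(input_shape[0], -1).shape)   # torch.Size([2, 5])  -- batch dim preserved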

Output:

in this paper we have a very good idea of how to use the data to make predictions about the future. We
we are trying to get the best possible deal for the best price. We are not going to be able to offer
The purpose of this workshop is to check whether we can make a difference in the lives of people who are struggling with mental illness.

8 reactions
patrickvonplaten commented, Mar 12, 2020

@schizism Concerning LM inference on batches of different lengths: this is actually a problem we are currently looking at. Ideally, you should be able to simply pass your input_ids (and an attention_mask) to model.generate() to make it work.

@XinyuHua thanks for your great contribution to making LM inference work on batches with different lengths. It also seems like you found a bug when using the past and attention_mask variables together as inputs to GPT-2. That’s great! I will open a new issue for that and take a look 😃

Below, I am adding a simplified code snippet using simpler tokenization functions. In this code, no past variable is used, because of the bug found by @XinyuHua.

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<PAD>')
# IMPORTANT: Note that setting the <PAD> token like this in the constructor gives the
# pad_token the pad_token_id = 50256, which normally belongs to the <BOS> token in GPT2.
# This is a very ugly workaround that, for the moment, sets the pad_token_id to the <BOS> token
# that is already included in the vocab size. This will be updated in the coming weeks!

prompt_text = [
    'in this paper we',
    'we are trying to',
    'The purpose of this workshop is to check whether we can']

# batch_encode_plus handles a batch of sentences and automatically creates the attention_mask
seq_len = 11
encodings_dict = tokenizer.batch_encode_plus(prompt_text, max_length=seq_len, pad_to_max_length=True)

# ideally we should be able to just pass the following two variables to model.generate() ... => to be implemented soon!
input_ids = torch.tensor(encodings_dict['input_ids'])
attn_mask = torch.tensor(encodings_dict['attention_mask'])

num_tokens_to_produce = 20
pad_token_id = tokenizer.pad_token_id
eos_token_id = tokenizer.eos_token_id
eos_not_in_sents = torch.ones(input_ids.shape[0]).long()

# we need the index of the last non-padded token in each sequence
last_non_masked_idx = torch.sum(attn_mask, dim=1) - 1
start_idx = last_non_masked_idx.view(-1, 1).repeat(1, tokenizer.vocab_size).unsqueeze(1)
past = None

# get correct position ids
position_ids = torch.tensor([list(range(seq_len)) for i in range(input_ids.shape[0])])
for i, position_ids_slice in enumerate(position_ids):
    position_ids_slice[last_non_masked_idx[i]:] = position_ids_slice[last_non_masked_idx[i]]

for step in range(num_tokens_to_produce):
    outputs = model(input_ids, attention_mask=attn_mask, position_ids=position_ids)

    # in the first decoding step, we want to use the 'real' last position for each sentence
    if step == 0:
        next_token_logits = outputs[0].gather(1, start_idx).squeeze(1)
    else:
        next_token_logits = outputs[0][:, -1, :]

    next_tokens = torch.argmax(next_token_logits, dim=-1)

    # this updates which sentences have not seen an <EOS> token so far
    # if one <EOS> token was seen the sentence is finished
    eos_not_in_sents.mul_(next_tokens.ne(eos_token_id).long())

    # either append a padding token here if <EOS> has been seen or append next token
    tokens_to_add = next_tokens * (eos_not_in_sents) + pad_token_id * (1 - eos_not_in_sents)

    # Update input_ids, attn_mask and position_ids
    input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
    attn_mask = torch.cat([attn_mask, torch.ones((attn_mask.shape[0], 1)).long()], dim=1)
    position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=1)

[print(tokenizer.decode(output, skip_special_tokens=True)) for output in input_ids]
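
For reference, later transformers releases support this pattern directly through model.generate: with left padding, the last position of every row is a real prompt token, so no manual gathering is needed. A minimal sketch (argument names such as max_new_tokens assume a recent version):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # reuse eos as the padding token
tokenizer.padding_side = 'left'            # left-pad so the last position is a real token
model = GPT2LMHeadModel.from_pretrained('gpt2')

prompt_text = [
    'in this paper we',
    'we are trying to',
    'The purpose of this workshop is to check whether we can']

encodings = tokenizer(prompt_text, return_tensors='pt', padding=True)
output_ids = model.generate(
    encodings['input_ids'],
    attention_mask=encodings['attention_mask'],
    max_new_tokens=20,
    pad_token_id=tokenizer.eos_token_id)

for seq in output_ids:
    print(tokenizer.decode(seq, skip_special_tokens=True))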
Read more comments on GitHub >

