Can GPT2LMHeadModel do batch inference with variable sentence lengths?
Given that the GPT2 tokenizer does not have an internal pad_token_id, how do I pad sentences and do batch inference using GPT2LMHeadModel? Specifically, my code is:
```python
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt_text = [
    'in this paper we',
    'we are trying to',
    'The purpose of this workshop is to check whether we can', ]
tokens = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(x, add_prefix_space=True)) for x in prompt_text]
inputs = pad_sequence([torch.LongTensor(x) for x in tokens], batch_first=True, padding_value=tokenizer.eos_token_id)
outputs, past = model(input_ids=inputs, attention_mask=None)
```
This returns irrelevant predictions, since GPT-2 attends to the eos tokens used as padding and starts a new sentence within the batch.
Can anyone please share sample code that uses GPT2LMHeadModel to do batch inference with varying sentence lengths?
Thanks!
I tried a rough version: basically, add an attention mask over the padding positions and keep updating this mask as the generation grows. One thing worth noting is that in the first step, instead of extracting the output at position -1 for every sample, we need to keep track of each prompt's real ending position; otherwise the output at a padding position may sometimes be extracted and produce random results.
Code snippet:
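Roughly, the idea looks like the following sketch (not the exact snippet; the greedy decoding loop, the 20-step budget, and the position_ids bookkeeping are illustrative details added here):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt_text = [
    "in this paper we",
    "we are trying to",
    "The purpose of this workshop is to check whether we can",
]

# encode with a leading space so the first word is tokenized as a continuation
encoded = [torch.LongTensor(tokenizer.encode(" " + p)) for p in prompt_text]
lengths = torch.LongTensor([len(e) for e in encoded])
input_ids = pad_sequence(encoded, batch_first=True, padding_value=tokenizer.eos_token_id)
# 1 on real tokens, 0 on padding
attention_mask = (torch.arange(input_ids.size(1))[None, :] < lengths[:, None]).long()

generated = [ids.tolist() for ids in encoded]
with torch.no_grad():
    for step in range(20):
        # positions must skip the padding, otherwise shorter prompts get wrong ones
        position_ids = (attention_mask.cumsum(dim=1) - 1).clamp(min=0)
        logits = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       position_ids=position_ids)[0]
        if step == 0:
            # read each sample's logits at its last *real* prompt token,
            # not at position -1, which may be padding
            next_logits = logits[torch.arange(len(prompt_text)), lengths - 1, :]
        else:
            next_logits = logits[:, -1, :]
        next_tokens = next_logits.argmax(dim=-1)
        # append the new tokens and extend the mask so they are attended to
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_tokens)[:, None]], dim=1)
        for i, t in enumerate(next_tokens.tolist()):
            generated[i].append(t)

for ids in generated:
    print(tokenizer.decode(ids))
```

The key points are that the attention mask grows by one for every generated token and that, on the first step, the logits are read at lengths - 1 rather than at -1.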
Also, a minor change is needed in src/transformers/modeling_gpt2.py, line 422:
attention_mask = attention_mask.view(-1, input_shape[-1])
should be changed to
attention_mask = attention_mask.view(input_shape[0], -1)
(not sure if this change will break other things).
Output:
in this paper we have a very good idea of how to use the data to make predictions about the future. We
we are trying to get the best possible deal for the best price. We are not going to be able to offer
The purpose of this workshop is to check whether we can make a difference in the lives of people who are struggling with mental illness.
@schizism Concerning LM inference on batches of different lengths: this is actually a problem we are currently looking at. Ideally, you should be able to simply pass your input_ids (and an attention_mask) to model.generate() to make it work.
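For reference, later transformers releases do support this; a minimal sketch, assuming the eos token is reused as the pad token and padding is done on the left so generation continues from the last real token:

```python
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token   # GPT-2 ships without a pad token
tokenizer.padding_side = "left"             # keep the real tokens at the end
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

batch = tokenizer(
    ["in this paper we", "we are trying to",
     "The purpose of this workshop is to check whether we can"],
    return_tensors="pt", padding=True,
)
output_ids = model.generate(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_length=30,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```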
@XinyuHua thanks for your great contribution to making LM inference work on batches of different lengths. It also seems like you found a bug when using the past and attention_mask variables together as inputs to GPT2. That's great! I will open a new issue for that and take a look 😃
Below, I am adding a simplified code snippet using simpler tokenization functions. In this code, no past variable is used, related to the bug found by @XinyuHua.
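A sketch in that spirit (illustrative, and not necessarily the exact snippet referred to above), assuming a transformers version whose tokenizer supports batched padding: one full forward pass, no past, and the next-token prediction read at each prompt's last real token.

```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # reuse eos as the pad token
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompts = ["in this paper we", "we are trying to",
           "The purpose of this workshop is to check whether we can"]
batch = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(input_ids=batch["input_ids"],
                   attention_mask=batch["attention_mask"])[0]

# index of the last non-padding token in every row (right padding assumed)
last = batch["attention_mask"].sum(dim=1) - 1
next_tokens = logits[torch.arange(len(prompts)), last, :].argmax(dim=-1)
for prompt, token_id in zip(prompts, next_tokens.tolist()):
    print(prompt, "->", tokenizer.decode([token_id]))
```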