
Adding State-of-the-art Contrastive Search to the Codebase of model.generate()


Feature request


<span id='all_catelogue'/>

Catalogue:

1. Abstract
2. Introduction
3. Demonstration of the Awesome Results from Contrastive Search
4. Example Usage
5. Code Snippet
6. Inference Latency
References


<span id='abstract'/>

1. Abstract: [Back to Top]

In this issue, we try to integrate contrastive search into the codebase of model.generate() as an additional option for text generation. We believe it would greatly benefit the research community.

All resources related to our work have been open-sourced; please find them below.


<span id='introduction'/>

2. Introduction: [Back to Top]

Open-ended text generation is one core task in NLP. However, the maximization-based decoding methods (e.g., greedy search and beam search) of neural language models often lead to the degeneration problem, i.e., the generated text is unnatural and contains undesirable repetitions. Existing approaches address the text degeneration problem by introducing stochasticity via sampling (e.g. top-k sampling [1] and nucleus sampling [2]), but they often lead to solutions that lack coherence.

In our recent NeurIPS 2022 paper [3], “A Contrastive Framework for Neural Text Generation”, we propose a new decoding method, i.e. contrastive search, which can be directly applied to all families of off-the-shelf language models (e.g. GPT and OPT). Specifically, during the decoding process, contrastive search selects from the most probable candidates predicted by the model while taking into account the degeneration penalty computed from the previous context. Formally, at each decoding step, given the context $\boldsymbol{x}_{< t}$, the selection of the output token $\boldsymbol{x}_t$ follows:

$$x_t = \underset{v \in V^{(k)}}{\arg\max} \Big\{ (1-\alpha) \cdot p_{\theta}(v \mid \boldsymbol{x}_{< t}) \;-\; \alpha \cdot \max_{1 \le j \le t-1} s(h_v, h_{x_j}) \Big\}$$

where $V^{(k)}$ is the set of top-k predictions from the model’s probability distribution. The second term, the degeneration penalty, measures how discriminative the candidate $v$ is with respect to the tokens in the previous context, and $s(h_v, h_{x_j})$ is the cosine similarity between the representations of the candidate $v$ and the previous token ${x_j}$. (The core implementations of contrastive search can be found in Section 5.)
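
To make this trade-off concrete, here is a tiny numeric sketch (with hypothetical numbers): a candidate with high model confidence whose representation nearly duplicates the context can lose to a less probable but more discriminative candidate.

import torch

# Hypothetical values for two candidates from V^(k), with alpha = 0.6.
alpha = 0.6
model_confidence = torch.tensor([0.70, 0.20])        # p(v | x_<t) for candidates A and B
degeneration_penalty = torch.tensor([0.95, 0.10])    # max cosine similarity to the context

scores = (1 - alpha) * model_confidence - alpha * degeneration_penalty
print(scores)  # tensor([-0.2900,  0.0200]) -> candidate B wins despite its lower probability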

Our proposed contrastive search generates text that is coherent with the prefix while maintaining the diversity of the generated result. Through extensive experiments, we demonstrate the clear superiority of contrastive search over all existing methods on both automatic metrics (e.g. MAUVE [4]) and human evaluations.


<span id='demonstration'/>

3. Demonstration of the Awesome Results from Contrastive Search: [Back to Top]

In this section, we provide direct demonstrations of how well contrastive search works. To this end, we select two representative models: (i) OPT [5], the recently released model by Meta which successfully replicates the results of GPT-3 [6]; and (ii) GPT-2 [7], (arguably) the most widely used model for text generation.

<span id='opt_demonstration'/>

3.1. Demonstration with OPT: [Back to Top]

First, let’s see the results of OPT models.

Take the conversation generation task from the original paper [5] (Figure 9 in Appendix E) as an example. The prefix text is shown below.

A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?

(i) The generated text with OPT-175B + nucleus sampling, as shown in Figure 9 in Appendix E of the original paper, is as follows:

A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: I have been there since 1886.
Human: What is your job?
Statue: I was built to welcome immigrants to the United States.
Human: What is your favorite food?
Statue: I like ice cream and apples.
Human: What is your favorite color?
Statue: I like red, white, and blue.
Human: What is your favorite animal?
Statue: I like dogs.
Human: What is your favorite song?

Obviously, even with 175B parameters, the model gets stuck in simple and linguistically repetitive generations when using nucleus sampling.

(ii) Now, let’s see what result we can get using contrastive search with OPT-6.7B [model card] (a model more than 20 times smaller than OPT-175B).

[Reproduce Results] In Section 4.2, we show how to reproduce our results with a few lines of code.

A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: Since 1884.
Human: Why did you come to America?
Statue: I was given to the United States by France as a gift for helping the French during the Franco-Prussian War.
Human: What do you think of America?
Statue: I love it. It is the greatest country in the world.
Human: What’s the weather like in New York?
Statue: It is cold.
Human: Is it safe to walk around at night?
Statue: Yes. There are policemen everywhere.
Human: Do you have any children?
Statue: Not yet. My pedestal is empty.
Human: What would you like to say to people who want to immigrate to America?
Statue: Come on over. You will be happy here. We have everything you need.

We see that the generated text with contrastive search is much more diverse and interesting.

[Comparison] For a more direct comparison, the generated results with the same OPT-6.7B using greedy search and nucleus sampling are:

(1) Text generated by greedy search:
----------------------------------------------------------------------------------------------------
A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: I have lived here for over 100 years.
Human: What do you do?
Statue: I welcome people from all over the world to come to America.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human: What do you think of America?
Statue: I love America.
Human: What do you think of immigrants?
Statue: I love immigrants.
Human...
----------------------------------------------------------------------------------------------------
(2) Text generated by nucleus sampling:
----------------------------------------------------------------------------------------------------
A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?
Statue: Since 1876.
Human: Why is the Statue of Liberty guarded?
Statue: Because there are many people trying to steal her.

a comparison about an unexpressed thought

I would also share the story of “A Humble Fear.” At a conference in New York the Dalai Lama gave a 
speech to the International Thinkers Congress in New York. The whole thing was recorded, and the 
video is quite interesting. (on a side note, I love the fact that there were some people who laughed
when he described himself as a humble being… I think the video is hilarious, there is a reason why
I put up the video. Because if you cannot find the humor in this you’re sadly lacking…)

In the speech, the Dalai Lama compares the search for truth to searching for treasure. He says: 
“However there is a huge difference between being a thief and a collector. A thief simply takes things, 
whereas a collector looks for the beauty, even if it is just a single object.”

The above quote is perhaps the most cliched Buddhist philosophy of our times. However the comparison
between a collector and a thief is quite interesting. I like to think that the Buddha...
----------------------------------------------------------------------------------------------------

We see that (i) greedy search generates repetitive text; and (ii) nucleus sampling produces text that is incoherent.

<span id='gpt_demonstration'/>

3.2. Demonstration with GPT: [Back to Top]

Next, let’s see the results of GPT models.

We provide a simple prefix text (DeepMind Company is) with only three words and ask the model to generate a long continuation of 512 tokens. In this example, we use GPT-2-large [model card] for text generation.

[Reproduce Results] In Section 4.3, we show how to reproduce our results with a few lines of code.

(1) Generated result with contrastive search:

----------------------------------------------------------------------------------------------------
DeepMind Company is a leader in artificial intelligence (AI). We have a long history of working with 
companies such as Google, Facebook, Amazon, and Microsoft to build products that improve people's lives, 
and today we are excited to announce that DeepMind's AlphaGo program has won the game of Go, becoming 
the first program to defeat a professional Go player.

The victory is a testament to the power of deep learning, and to the incredible work of our research team, 
which has been at the forefront of AI research for the past five years. AlphaGo is one of the most advanced 
Go programs ever created, and its performance is an important step towards the goal of human-level AI.

"This is the culmination of a decade of hard work," said Andy Ng, co-founder and CTO of DeepMind. "We are 
thrilled to have achieved this milestone and look forward to continuing to develop AI that can be used in 
a wide range of applications and to help people live better lives."

DeepMind's work on Go began in 2010, when it began to train a neural network to play Go using millions of 
games played by top Go players around the world. Since then, the team has refined the algorithm, adding 
more and more layers of reinforcement learning to make it better at recognizing patterns and making decisions 
based on those patterns. In the past year and a half, the team has made significant progress in the game, 
winning a record-tying 13 games in a row to move into the top four of the world rankings.

"The game of Go is a complex game in which players have to be very careful not to overextend their territory, 
and this is something that we have been able to improve over and over again," said Dr. Demis Hassabis, co-founder
and Chief Scientific Officer of DeepMind. "We are very proud of our team's work, and we hope that it will inspire
others to take the next step in their research and apply the same techniques to other problems."

In addition to the win in Go, DeepMind has also developed an AI system that can learn to play a number of different
games, including poker, Go, and chess. This AI system, called Tarsier, was developed in partnership with Carnegie
Mellon University and the University of California, Berkeley, and is being used to teach computer vision and machine
learning to identify objects in images and recognize speech in natural language. Tarsier has been trained to play
the game of Go and other games on a number of different platforms...
----------------------------------------------------------------------------------------------------

From the results, we can see that the entire generated document is of very high quality and reads like human-written text.

[Comparison] For a more direct comparison, the generated results with the same model using greedy search and nucleus sampling are:

(2) Text generated by greedy search:
----------------------------------------------------------------------------------------------------
DeepMind Company is a leading AI research company, with a focus on deep learning and deep learning-based systems.

The company's research is focused on the development of deep learning-based systems that can learn from large 
amounts of data, and that can be used to solve real-world problems.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's research is also used by the UK government to develop new technologies for the UK's National Health Service.

DeepMind's
----------------------------------------------------------------------------------------------------
(3) Text generated by nucleus sampling:
----------------------------------------------------------------------------------------------------
DeepMind Company is a Cardiff-based start-up with an exclusive mission to build the world's largest 
ever deep-learning system to analyse the world's digital content and in particular, super-sized image
content.
  
The system, the largest in the world with no previous expertise in image or digital content detection,
will have previously relied on a mixture of machine learning, artificial neural networks, and storage,
processing and retrieval techniques.
  
The AI system, called ImageNet, will take new approach to our challenge of data science and machine
learning, significantly improving efficiency, natural language processing and full understanding of 
complex, high-dimensional images, with an Eye of the Tiger framework for extracting techniques to 
ensure correct detection of particular images in complex scenes.
 
Dr. Mark Ward, Dr. Alex Kudle, Dr. Ralph Pinchbeck and CTO, DeepMind Dr. Alex Kudle
  
Case Study: Derpy's Most Wanted: Fighting Cybersecurity, building a robot-aided smuggling network
  
InfoSec News, 06/07/2017
  
Dimitrios Papadimitriou (left) and Chris Bardy (right) at G+ XE, July 2017
  
How to model an industrial malware botnet
  
In this case study, we show how to build a deep-learning environment to model a new, massive ransomware
botnet. Our model computes the distribution of user credentials stored on infected machines and produces
a toolkit for open-source "modeling-as-code" (MATC) simulation. We elaborate on the resource management
aspect of the toolkit, and how it can be adapted to working offline on embedded or cloud-based networks.
  
Hacking Networked: The industrial botnets of the future
  
InfoSec News, 04/11/2017
  
Intensive analysis of state sponsored malicious cyber activity, published by KBB Strategic
  
The major single source of IoT malware networks in 2017
  
The global commercial botnet equivalent count grew to 31.5% in 2017, up from 21.1% the year before, 
according to a comprehensive report from the Government Accountability Office (GAO). According to the 
report, various malware operators continued to convert massive amounts of wasted data into profits as
well as enable sophisticated cyber operations targeting critical infrastructure.
  
Industrial malware blasts up to 31\% of malware within the IP space over 2017...
----------------------------------------------------------------------------------------------------

Obviously, greedy search generates repetitive text, while nucleus sampling produces text that is incoherent and quickly goes off-topic.


<span id='example_usage'/>

4. Example Usage: [Back to Top]

In our [main repo], we have provided detailed huggingface-style tutorials ([tutorial 1], [tutorial 2]) on how to apply contrastive search to different models across different languages.

In the following, we show how to easily reproduce our results in Section 3 with a few lines of code.

<span id='installation'/>

4.1. Environment Setup:

For ease of use, we have provided a PyPI package, which can be installed as shown below. More details of our package can be found [here].

pip install simctg --upgrade
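
After installation, a quick import check (using the same module paths as the examples below) can confirm that the package is available:

# Minimal sanity check that the simctg package was installed correctly.
from simctg.simctggpt import SimCTGGPT
from simctg.simctgopt import SimCTGOPT
print("simctg is installed")
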
<span id='reproduce_opt'/>

4.2. Reproduce Results of OPT:

To reproduce our results in Section 3.1 using OPT, (i) we first load the OPT model as

import torch
from simctg.simctgopt import SimCTGOPT
model_name = 'facebook/opt-6.7b'
model = SimCTGOPT(model_name)
tokenizer = model.tokenizer
model.eval()
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id

(ii) Then, we provide the prefix text as

prefix_text = r"""A chat between a curious human and the Statue of Liberty.

Human: What is your name?
Statue: I am the Statue of Liberty.
Human: Where do you live?
Statue: New York City.
Human: How long have you lived there?"""

(iii) Thirdly, we prepare the input ids as

[Important Tip] As the authors suggested in their [tutorial], OPT adds the EOS token to the beginning of every prompt. So make sure the special token is added at the front of the prompt.

tokens = tokenizer.tokenize(prefix_text)
input_ids = [bos_token_id] + tokenizer.convert_tokens_to_ids(tokens) # adds </s> to the beginning of every prompt
input_ids = torch.LongTensor(input_ids).view(1,-1)

(iv) Lastly, we generate the text with contrastive search as

beam_width, alpha, decoding_len = 5, 0.6, 256
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width,
                                        alpha=alpha, decoding_len=decoding_len,
                                        end_of_sequence_token_id=eos_token_id, early_stop=True)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output[1:]))  # skip the prepended </s> token when decoding
print(100 * '-')
<span id='reproduce_gpt'/>

4.3. Reproduce Results of GPT:

To reproduce our results in Section 3.2 using GPT, (i) we first load the GPT-2 model as

import torch
from simctg.simctggpt import SimCTGGPT
model_name = r'gpt2-large'
model = SimCTGGPT(model_name)
model.eval()
tokenizer = model.tokenizer
eos_token_id = tokenizer.eos_token_id

(ii) Then, we prepare the prefix text as

prefix_text = r"DeepMind Company is"
tokens = tokenizer.tokenize(prefix_text)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

(iii) Last, we generate the text with contrastive search as

beam_width, alpha, decoding_len = 4, 0.6, 512
output = model.fast_contrastive_search(input_ids=input_ids, beam_width=beam_width,
                                        alpha=alpha, decoding_len=decoding_len,
                                        end_of_sequence_token_id=eos_token_id, early_stop=True)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(output))
print(100 * '-')

<span id='code_snippet'/>

5. Code Snippet: [Back to Top]

The main implementation of contrastive search involves two parts: (i) candidate collection; and (ii) candidate re-ranking.

For more details, please find our open-sourced implementations for [GPT-2 models] and [OPT models].

(i) The collection of candidates can be implemented as below:

import torch
import torch.nn.functional as F

def ContrastiveSearchOneStep(model, input_ids, beam_width, alpha):
    '''
        model: the generation model, e.g., gpt2
        input_ids: 1 x seqlen
    '''
    prev_hidden_states, logits = model.compute_logits_and_hidden_states(input_ids)
    _, seqlen, embed_dim = prev_hidden_states.size()
    _, _, vocab_size = logits.size()

    logit_for_next_step = logits[:,-1,:]
    assert logit_for_next_step.size() == torch.Size([1, vocab_size])

    next_probs = F.softmax(logit_for_next_step, dim = -1)
    assert next_probs.size() == logit_for_next_step.size()

    # collect the top-k candidates and their model confidences
    _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
    assert top_k_ids.size() == torch.Size([1, beam_width])

    top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
    assert top_k_probs.size() == top_k_ids.size()

    # compute new hidden states by appending each candidate token to the context
    expanded_context = [input_ids for _ in range(beam_width)]
    expanded_context = torch.cat(expanded_context, dim = 0)
    assert expanded_context.size() == torch.Size([beam_width, seqlen])
    top_k_ids = top_k_ids.view(beam_width, 1)
    next_input_ids = torch.cat([expanded_context, top_k_ids], dim = -1)
    assert next_input_ids.size() == torch.Size([beam_width, seqlen+1])
    new_hidden_states, next_logits = model.compute_logits_and_hidden_states(next_input_ids)
    assert new_hidden_states.size() == torch.Size([beam_width, seqlen+1, embed_dim])
    context_hidden = new_hidden_states[:,:seqlen,:]
    assert context_hidden.size() == torch.Size([beam_width, seqlen, embed_dim])
    next_hidden = new_hidden_states[:,seqlen:,:]
    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])

    # re-rank the candidates with the contrastive objective and keep the best one
    next_id = ranking(context_hidden, next_hidden, top_k_ids, top_k_probs, alpha)

    next_input_ids = torch.cat([input_ids, next_id], dim = -1)
    assert next_input_ids.size() == torch.Size([1, seqlen+1])
    return next_input_ids
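
As a usage illustration, a full decoding loop can be built on top of ContrastiveSearchOneStep. The sketch below is illustrative (the loop and its early-stopping check are additions of ours, not the library implementation) and assumes the same model interface as the snippet above:

def contrastive_search(model, input_ids, beam_width, alpha, decoding_len, eos_token_id=None):
    '''
        Repeatedly applies ContrastiveSearchOneStep to extend input_ids by up to
        decoding_len tokens, stopping early if the end-of-sequence token is produced.
    '''
    for _ in range(decoding_len):
        input_ids = ContrastiveSearchOneStep(model, input_ids, beam_width, alpha)
        if eos_token_id is not None and input_ids[0, -1].item() == eos_token_id:
            break
    return input_ids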

(ii) The re-ranking of candidates can be implemented as below:

def ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha):
    '''
        context_hidden: beam_width x context_len x embed_dim
        next_hidden: beam_width x 1 x embed_dim
        next_top_k_ids: beam_width x 1
    '''
    beam_width, context_len, embed_dim = context_hidden.size()
    assert next_hidden.size() == torch.Size([beam_width, 1, embed_dim])
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
    assert cosine_matrix.size() == torch.Size([beam_width, context_len])
    # degeneration penalty: maximum cosine similarity to any token in the previous context
    scores, _ = torch.max(cosine_matrix, dim = -1)
    assert scores.size() == torch.Size([beam_width])
    next_top_k_probs = next_top_k_probs.view(-1)
    # contrastive objective: model confidence minus the weighted degeneration penalty
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    _, selected_idx = torch.topk(scores, k = 1)
    assert selected_idx.size() == torch.Size([1])
    selected_idx = selected_idx.unsqueeze(0)
    assert selected_idx.size() == torch.Size([1,1])
    next_id = torch.gather(next_top_k_ids, dim = 0, index=selected_idx)
    assert next_id.size() == torch.Size([1,1])
    return next_id
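
As a quick illustration of the re-ranking behaviour, the toy example below uses hand-built (hypothetical) tensors: the more probable candidate is deliberately made identical to one of the context representations, so the degeneration penalty pushes the selection to the other candidate.

import torch

# beam_width = 2, context_len = 3, embed_dim = 4
context_hidden = torch.tensor([[[1., 0., 0., 0.],
                                [0., 1., 0., 0.],
                                [0., 0., 1., 0.]]]).repeat(2, 1, 1)
next_hidden = torch.tensor([[[1., 0., 0., 0.]],    # candidate 0 duplicates a context state
                            [[0., 0., 0., 1.]]])   # candidate 1 is orthogonal to the context
next_top_k_ids = torch.tensor([[11], [22]])
next_top_k_probs = torch.tensor([[0.6], [0.4]])    # candidate 0 is more probable

print(ranking(context_hidden, next_hidden, next_top_k_ids, next_top_k_probs, alpha=0.6))
# tensor([[22]]) -- the less repetitive candidate is selected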

<span id='inference_latency'/>

6. Inference Latency: [Back to Top]

Lastly, we compare the inference latency of contrastive search with other widely used decoding methods. The results are shown in the Figure below.

[Figure: inference latency of contrastive search compared with other widely used decoding methods]

We see that the inference latency of contrastive search is comparable to that of other widely used methods, which further verifies the practicality of our proposed approach.


<span id='reference'/>

References:

[1] Fan et al., 2018, “Hierarchical Neural Story Generation”, ACL 2018

[2] Holtzman et al., 2020, “The Curious Case of Neural Text Degeneration”, ICLR 2020

[3] Su et al., 2022, “A Contrastive Framework for Neural Text Generation”, NeurIPS 2022

[4] Pillutla et al., 2021, “MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers”, NeurIPS 2021

[5] Zhang et al., 2022, “OPT: Open Pre-trained Transformer Language Models”, arXiv 2022

[6] Brown et al., 2020, “Language Models are Few-Shot Learners”, NeurIPS 2020

[7] Radford et al., 2019, “Language Models are Unsupervised Multitask Learners”

Motivation

Given the exceptional performances of contrastive search, we certainly believe that it would greatly benefit a wide range of NLP researchers/practitioners in the text generation community.

Your contribution

I can submit a PR for this requested feature ASAP.


Top GitHub Comments

2 reactions
gante commented, Nov 1, 2022

@yxuansu @stas00 actually it works for nearly all models, except for Bloom (which has a different shape for the past key values output) – working on it 😃

2 reactions
gmftbyGMFTBY commented, Oct 3, 2022

@gante Hi, thank you so much for your suggestions, we’ve almost prepared the PyTorch version codebase of contrastive_search in our fork. I have sent you an invitation to our repo.

All the changes are in src/transformers/generation_utils.py, and you can check them there. Furthermore, we have also prepared a test script so you can run contrastive_search easily. To run this test script, please run the following commands:

cd tests/generation
CUDA_VISIBLE_DEVICES=0 python test_generation_contrastive_search.py

Looking forward to your valuable questions and suggestions.

Best, TianLan
