
Does the 'bad_words_ids' argument in the generate() function work?

See original GitHub issue

Environment info

  • transformers version: 4.12.0
  • Platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.7.12
  • PyTorch version (GPU?): 1.9.0+cu111 (False)
  • Tensorflow version (GPU?): 2.6.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

Information

I attempted to evaluate whether the bad_words_ids argument available in the generate() function actually works. However, based on the steps described in the section below, it doesn’t work.

To reproduce

Below are the steps I used to evaluate:

  1. Run the script without bad_words_ids specified, using set_seed to get deterministic output.
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

set_seed(0)

input_context = "My cute dog"
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Output: Generated: My cute dog, when it died, had taken my entire life to save the life that had been

  2. Re-run the script with bad_words_ids specified. I chose the words “entire” and “save” from the previously generated sequence. However, both words still appear, and the output is identical to the previous one. Below is the script with its output.
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

set_seed(0)

input_context = "My cute dog"
# get tokens of words that should not be generated
bad_words_ids = [tokenizer(bad_word).input_ids for bad_word in ["entire", "save"]]
# encode input context
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Output: Generated: My cute dog, when it died, had taken my entire life to save the life that had been

To reproduce in Google Colab:

https://colab.research.google.com/drive/1P4ruLhFstbal1qqXbjuv-kM7yMYY-S1E?usp=sharing

Expected behavior

I expect the words “entire” and “save” not to appear in the output sequence after running step (2) in the section above.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 13 (4 by maintainers)

Top GitHub Comments

3 reactions
qqaatw commented, Oct 30, 2021

Hey @alvinwatner,

To prevent bad words from occurring in the middle of generated text, you’ll need to add a prefix space to every bad word, so that a tokenized bad word such as save becomes ['Ġsave'] instead of ['save'], which matches GPT-2’s output tokens.

This can be done by setting add_prefix_space=True in the kwargs of from_pretrained.
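
To see the difference concretely, here is a minimal check (a sketch; only the gpt2 tokenizer is involved, and the exact printed tokens are whatever your tokenizer version returns):

from transformers import AutoTokenizer

# default tokenizer: no prefix space, so a standalone word keeps its bare form
tok_default = AutoTokenizer.from_pretrained("gpt2")
print(tok_default.tokenize("save"))  # e.g. ['save']

# with add_prefix_space=True the word is tokenized as if preceded by a space,
# which is how it actually appears mid-sentence in GPT-2's output
tok_prefix = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
print(tok_prefix.tokenize("save"))  # e.g. ['Ġsave']

With the prefix-space ids in place, the full example looks like this: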

from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

# return_dict_in_generate=True makes generate() return a dict-like output,
# so the generated ids are read from outputs["sequences"] below
model = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

set_seed(0)

input_context = "My cute dog"
# get tokens of words that should not be generated
bad_words_ids = tokenizer(["entire", "save"]).input_ids
# encode input context
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs["sequences"][0], skip_special_tokens=True))

Output:

Generated:  My cute dog, when it died, had taken my hand out of my pants and said "I
0 reactions
musitafa0032 commented, Apr 4, 2022

Interesting, I just figured it out. For Chinese BART you only need the single token id to make it work, because Chinese characters have no prefix or suffix variants. If you use the tokenizer to get the bad word ids, it will return something like [[101, 704, 102]], but 101 and 102 represent [CLS] and [SEP]; you only need the 704 id.
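
One way to keep those special tokens out automatically is to pass add_special_tokens=False when building the bad word ids. Below is a minimal sketch; the checkpoint name fnlp/bart-base-chinese and the example character are assumptions, so substitute whichever Chinese BART tokenizer and words you actually use:

from transformers import AutoTokenizer

# assumed checkpoint; any BERT-style tokenizer that adds [CLS]/[SEP] behaves the same way
tokenizer = AutoTokenizer.from_pretrained("fnlp/bart-base-chinese")

bad_words = ["中"]  # hypothetical single-character bad word
# add_special_tokens=False keeps [CLS]/[SEP] out of each id list, leaving only
# the ids that can actually occur inside a generated sequence
bad_words_ids = [tokenizer(w, add_special_tokens=False).input_ids for w in bad_words]
print(bad_words_ids)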
