Does the 'bad_words_ids' argument in the generate() function work?
Environment info
- transformers version: 4.12.0
- Platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.12
- PyTorch version (GPU?): 1.9.0+cu111 (False)
- Tensorflow version (GPU?): 2.6.0 (False)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
Information
I attempted to evaluate whether the bad_words_ids argument available in the generate() function works or not. However, based on the steps described in the section below, it doesn't work.
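For context, generate() applies bad_words_ids through a logits processor that masks banned token sequences at each decoding step. Below is a minimal sketch of that component, assuming transformers 4.12, where NoBadWordsLogitsProcessor is exported at the top level:

import torch
from transformers import AutoTokenizer, NoBadWordsLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# token-id sequences to ban, in the format generate() expects
bad_words_ids = [tokenizer(word).input_ids for word in ["entire", "save"]]
# generate() builds this processor internally from the bad_words_ids argument;
# it sets the logits of tokens that would complete a banned sequence to -inf
processor = NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id=tokenizer.eos_token_id)
prefix_ids = tokenizer("My cute dog", return_tensors="pt").input_ids
scores = torch.zeros(1, len(tokenizer))
masked = processor(prefix_ids, scores)
print("logits masked at this step:", int(torch.isinf(masked).sum()))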
To reproduce
Below are the steps I used to evaluate:
- Run the script without bad_words_ids specified, using set_seed to get deterministic output.
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
set_seed(0)
input_context = "My cute dog"
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
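# generate a sequence with sampling, without specifying bad_words_ids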
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
Generated: My cute dog, when it died, had taken my entire life to save the life that had been
- Re-run the script, but with bad_words_ids specified. I selected the words "entire" and "save" from the previously generated sequence. However, both words still appear in the output sequence, which is identical to the previous one. Below is the script and its output.
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, set_seed
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
set_seed(0)
input_context = "My cute dog"
# get tokens of words that should not be generated
bad_words_ids = [tokenizer(bad_word).input_ids for bad_word in ["entire", "save"]]
# encode input context
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:
Generated: My cute dog, when it died, had taken my entire life to save the life that had been
To reproduce in Google Colab:
https://colab.research.google.com/drive/1P4ruLhFstbal1qqXbjuv-kM7yMYY-S1E?usp=sharing
Expected behavior
I expect the words "entire" and "save" not to be included in the output sequence after running step (2) in the section above.
Hey @alvinwatner,
To prevent bad words from occurring in the middle of generated texts, you'll need to add a prefix space to every bad word so that a tokenized bad word such as save becomes ['Ġsave'] instead of ['save'], which matches GPT2's outputs. This can be done by setting add_prefix_space=True in the kwargs of from_pretrained.
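A minimal sketch of that fix (same GPT2 setup as above; note that with add_prefix_space=True the prompt is also encoded with a leading space):

from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
# with add_prefix_space=True, "save" tokenizes to ['Ġsave'], matching how the
# word appears mid-sentence in GPT2's output
tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained("gpt2")
set_seed(0)
input_context = "My cute dog"
bad_words_ids = [tokenizer(bad_word).input_ids for bad_word in ["entire", "save"]]
input_ids = tokenizer(input_context, return_tensors="pt").input_ids
outputs = model.generate(input_ids=input_ids, max_length=20, do_sample=True, bad_words_ids=bad_words_ids)
print("Generated:", tokenizer.decode(outputs[0], skip_special_tokens=True))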
Interesting, I just figured it out. For Chinese BART, you only need the single token id to make it work, because Chinese characters have no prefix/suffix variants. If you use the tokenizer to get the bad word ids, it will return something like [[101, 704, 102]], but 101 and 102 represent [CLS] and [SEP]; you only need the 704 id.
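For illustration, a short sketch of that, assuming a BERT-style Chinese tokenizer; the checkpoint name is a placeholder and the exact ids depend on its vocabulary:

from transformers import AutoTokenizer
# placeholder checkpoint; substitute the Chinese BART model actually used
tokenizer = AutoTokenizer.from_pretrained("fnlp/bart-base-chinese")
word = "中"  # example single Chinese character
# a plain call wraps the character in special tokens, e.g. [101, <char id>, 102],
# where 101 is [CLS] and 102 is [SEP]
print(tokenizer(word).input_ids)
# dropping the special tokens leaves only the character's own id, which is the
# form bad_words_ids needs
bad_words_ids = [tokenizer(word, add_special_tokens=False).input_ids]
print(bad_words_ids)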