Allow adding custom logits processors in the `generate` method
🚀 Feature request
Hello,
I’d like to request a new feature in the `generate` method of the `GenerationMixin` class from `generation_utils`. Specifically, I’d like a feature that allows a user to pass custom `LogitsProcessor`s by adding a new argument, `logit_processors: Optional[LogitsProcessorList] = None`, to the `generate` method.
Motivation
I’d like to run generation on a pre-trained model and modify its output logits with my own custom function before greedy search, beam search, sampling, or whichever decoding strategy is applied. I think this could be a common use case for controlled natural language generation, because one often wants to impose simple restrictions on the generated logits.
Here is an example of how this could be used:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList


class MyLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Modify `scores` (shape: batch_size x vocab_size) with a custom function here,
        # then return the (possibly changed) scores.
        return scores


model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
logit_processors = LogitsProcessorList([MyLogitsProcessor()])
input_ids = tokenizer('This dog is cute', return_tensors='pt').input_ids
# `logit_processors` is the new argument proposed in this issue.
model.generate(input_ids=input_ids, logit_processors=logit_processors)
```
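As a concrete (hypothetical) instance of such a restriction, the sketch below forbids a fixed set of token ids by setting their logits to `-inf`, so they get zero probability after softmax. The class name `BanTokensLogitsProcessor` is made up for illustration, and it only follows the `(input_ids, scores) -> scores` call convention; in real code it would subclass `transformers.LogitsProcessor`, which is left out here so the example depends only on PyTorch.

```python
import torch


class BanTokensLogitsProcessor:
    """Hypothetical processor: forbid a fixed set of token ids at every step.

    Uses the LogitsProcessor call convention (input_ids, scores) -> scores;
    in practice you would subclass transformers.LogitsProcessor.
    """

    def __init__(self, banned_token_ids):
        self.banned_token_ids = list(banned_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # scores has shape (batch_size, vocab_size); work on a copy so the input is untouched.
        scores = scores.clone()
        # A -inf logit becomes probability 0 after softmax.
        scores[:, self.banned_token_ids] = float("-inf")
        return scores
```

With the argument proposed in this issue, it would be passed as `model.generate(input_ids=input_ids, logit_processors=LogitsProcessorList([BanTokensLogitsProcessor([tokenizer.eos_token_id])]))`.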
Your contribution
I have no experience in open source, but I can try to help if you need a hand. I think that the general approach to implementing this is to do the following:
- Add the `logit_processors: Optional[LogitsProcessorList] = None` argument to the `generate` method.
- Add the same argument to the `_get_logits_processor` method of `GenerationMixin`, and append the custom logits processors after all the other logits processors are in place.
- Pass the custom logits processors to every call of `_get_logits_processor` in the `generate` method.
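A dependency-free sketch of the ordering described in these steps might look like the following. It is a simplified model, not the actual `_get_logits_processor` code: processors are plain functions over Python lists instead of torch tensors, and `get_logits_processor`, `apply_processors`, `temperature_2`, and `ban_token_0` are hypothetical names.

```python
from typing import Callable, List, Optional, Sequence

# Simplified stand-in for a logits processor: (input_ids, scores) -> new scores.
# Real transformers processors work on torch tensors; lists keep this sketch self-contained.
Processor = Callable[[Sequence[int], List[float]], List[float]]


def get_logits_processor(
    default_processors: List[Processor],
    logit_processors: Optional[List[Processor]] = None,
) -> List[Processor]:
    """Hypothetical version of the proposed change: user processors go last."""
    processors = list(default_processors)  # copy so the defaults list is not mutated
    if logit_processors is not None:
        processors.extend(logit_processors)
    return processors


def apply_processors(processors: List[Processor], input_ids: Sequence[int], scores: List[float]) -> List[float]:
    # Each processor receives the scores produced by the previous one.
    for processor in processors:
        scores = processor(input_ids, scores)
    return scores


def temperature_2(input_ids: Sequence[int], scores: List[float]) -> List[float]:
    # Example "built-in" step: temperature scaling with T = 2.
    return [s / 2.0 for s in scores]


def ban_token_0(input_ids: Sequence[int], scores: List[float]) -> List[float]:
    # Example custom step: forbid token id 0.
    return [float("-inf")] + scores[1:]


processors = get_logits_processor([temperature_2], [ban_token_0])
print(apply_processors(processors, [5], [2.0, 4.0]))  # [-inf, 2.0]
```

Appending the custom processors last means they see scores that the built-in processors have already adjusted, so a user restriction such as a ban cannot be undone by a later default step.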
What do you think?
Top GitHub Comments
I think it’s a very nice idea!
I think the problem you mention, @patrickvonplaten, will mostly be relevant to power users (those who want to add a `LogitsProcessor`), so they should be careful about how they use this tool. I guess we could emphasize in the documentation for the `generate` function that the simpler arguments are preferred for non-advanced usage.

There used to be a PR that might be used as a starting point:
https://github.com/huggingface/transformers/pull/12219
Thanks if you can work on this!