
Allow adding custom logits processors in the `generate` method


🚀 Feature request

Hello, I’d like to request a new feature in the generate method of the GenerationMixin class from generation_utils. Specifically, I’d like a feature that allows a user to pass custom LogitsProcessors by adding a new argument logit_processors: Optional[LogitsProcessorList] = None to the generate method.

Motivation

I’d like to run generation on a pre-trained model and modify its output logits with a custom function before beam search, sampling, or whichever decoding method is used. I think this could be a common use case for controlled natural language generation, since one often wants to impose simple restrictions on the generated logits.

Here is an example of how this could be used:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList

class MyLogitsProcessor(LogitsProcessor):
   def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
      # a processor must return the (possibly modified) scores
      return something_useful(scores)


model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
logit_processors = LogitsProcessorList([MyLogitsProcessor()])
input_ids = tokenizer('This dog is cute', return_tensors='pt').input_ids
model.generate(input_ids=input_ids, logit_processors=logit_processors)
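To make the `__call__` contract concrete, here is a minimal, hypothetical processor that forbids a single token id. It is a pure-Python sketch of the signature above: plain lists stand in for the torch tensors, and `BanTokenProcessor` is an invented name for illustration, not part of transformers.

```python
NEG_INF = float("-inf")

class BanTokenProcessor:
    """Hypothetical processor: forbid one token id by masking its score."""

    def __init__(self, banned_id: int):
        self.banned_id = banned_id

    def __call__(self, input_ids, scores):
        # Return a copy with the banned token's score set to -inf,
        # so it can never be chosen during search or sampling.
        out = list(scores)
        out[self.banned_id] = NEG_INF
        return out

proc = BanTokenProcessor(banned_id=2)
print(proc([0, 1], [0.5, 1.2, 3.0, -0.1]))  # -> [0.5, 1.2, -inf, -0.1]
```

The key point is simply that `__call__` takes `(input_ids, scores)` and returns new scores, which is the contract the built-in processors in transformers follow.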

Your contribution

I have no experience in open source, but I can try to help if you need a hand. I think that the general approach to implementing this is to do the following:

  1. Add a logit_processors: Optional[LogitsProcessorList] = None argument to the generate method.
  2. Add the same argument to the _get_logits_processor method of GenerationMixin, and append the custom logits processors after all the built-in ones are in place.
  3. Pass the custom logits processors to every call of _get_logits_processor in the generate method.
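The merge in steps 2 and 3 could look roughly like the following plain-Python sketch. The names `build_logits_processor` and `apply_processors` are illustrative stand-ins, not the actual transformers internals; in the real code, `_get_logits_processor` would return a LogitsProcessorList rather than a plain list.

```python
def build_logits_processor(default_processors, custom_processors=None):
    # Start from the built-in processors, then append any user-supplied
    # ones so that custom logic runs after the defaults.
    processors = list(default_processors)
    if custom_processors is not None:
        processors.extend(custom_processors)
    return processors

def apply_processors(processors, input_ids, scores):
    # Each processor takes (input_ids, scores) and returns new scores.
    for proc in processors:
        scores = proc(input_ids, scores)
    return scores

# Example: a "default" processor that doubles scores, plus a "custom"
# one that zeroes the first entry.
defaults = [lambda ids, s: [x * 2 for x in s]]
custom = [lambda ids, s: [0.0] + s[1:]]
chain = build_logits_processor(defaults, custom)
print(apply_processors(chain, [0], [1.0, 2.0]))  # -> [0.0, 4.0]
```

Appending the custom processors last means user logic sees the logits after all built-in constraints have been applied, which seems like the least surprising ordering.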

What do you think?

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

3 reactions
Narsil commented, May 14, 2021

I think it’s a very nice idea!

The problem you mention, @patrickvonplaten, will I think be relevant mostly to power users (those who want to add a LogitsProcessor), so they should be careful about how they use this tool. I guess we could emphasize in the documentation for the generate function that the simpler arguments are preferred for non-advanced usage.

1 reaction
Narsil commented, Nov 26, 2021

There used to be a PR that might be used as a starting point:

https://github.com/huggingface/transformers/pull/12219

Thanks if you can work on this!


