Add a sampling_transform callback to generation for arbitrary probability-warps
🚀 Feature request
I’d like to add a sampling_transform callback argument to all generate functions in modeling_utils that allows arbitrary sampling from the probability distribution during sequence generation. The signature of this function would be (input_ids, next_probs, next_token) -> next_token.
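For concreteness, here is a minimal sketch of the proposed hook; the names (sampling_transform, next_probs) come from this proposal and are not part of the existing generate() API:

```python
# Identity transform illustrating the proposed signature.
def sampling_transform(input_ids, next_probs, next_token):
    # input_ids:  (batch, seq_len) tokens generated so far
    # next_probs: (batch, vocab)   probability distribution for the next token
    # next_token: (batch,)         token just sampled from next_probs
    return next_token  # keep the sampled token unchanged

# Hypothetical call site inside the sampling loop of generate():
# next_token = sampling_transform(input_ids, next_probs, next_token)
```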
Motivation
The generation function in modeling_utils is getting pretty hairy: it takes parameters for top-p, top-k, bad tokens, temperature, repetition penalty, etc. Every new way of sampling means we have to add more parameters, which further complicates the function. I believe the right way to solve this is to provide a function argument that lets users express arbitrary rules for the next sample. In the long run, we could replace all the other parameters with a set of pre-baked sampling_transform functions that can be composed at will, as sketched below.
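As a sketch of that composition idea (all names hypothetical), pre-baked transforms under the signature above could be chained like this:

```python
import torch

def compose_transforms(*transforms):
    """Chain token-returning transforms; each may override the sample."""
    def composed(input_ids, next_probs, next_token):
        for transform in transforms:
            next_token = transform(input_ids, next_probs, next_token)
        return next_token
    return composed

def top_k_transform(k=50):
    """Pre-baked top-k expressed as a transform: redraw from the top k."""
    def transform(input_ids, next_probs, next_token):
        top_probs, top_idx = next_probs.topk(k, dim=-1)
        # multinomial renormalizes, so the truncated probs need no rescaling
        choice = torch.multinomial(top_probs, num_samples=1)
        return top_idx.gather(-1, choice).squeeze(-1)
    return transform

# e.g. sampling_transform = compose_transforms(top_k_transform(50), my_rule)
```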
This approach scales better to strange warps. For example, in a project I’m working on (https://github.com/turtlesoupy/this-word-does-not-exist) I need to early-terminate sequences if they generate from a large set of bad tokens, and to keep generating if an EOS token is sampled too early. An example is here: https://github.com/turtlesoupy/this-word-does-not-exist/blob/260e33a8f420b9be8b1e7260cb03c74d6231686e/title_maker_pro/datasets.py#L386
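The early-EOS case could then be handled with a transform along these lines (a sketch in the proposed interface, not the linked project’s actual code):

```python
import torch

def no_early_eos_transform(eos_token_id, min_length):
    """Redraw the sample, excluding EOS, while the sequence is too short."""
    def transform(input_ids, next_probs, next_token):
        if input_ids.shape[-1] < min_length:
            probs = next_probs.clone()
            probs[:, eos_token_id] = 0.0  # forbid EOS before min_length
            redraw = torch.multinomial(probs, num_samples=1).squeeze(-1)
            next_token = torch.where(next_token == eos_token_id, redraw, next_token)
        return next_token
    return transform
```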
Your contribution
I’ll attach a sample PR that makes this work for PyTorch and non-beam sampling. If people like the idea, it should be easy to refine into a full PR that generalizes to beam search and TensorFlow.
Top GitHub Comments
@yjernite good call; the interface would have to be modified slightly to (input_ids, next_probs, next_token) -> (next_probs). You might be able to fold the sampling procedure into the transforms. I know in my case I want to operate after a sample has been chosen.

Very interesting idea! I think we eventually have to make the generation function more general anyway. Maybe it’s time to move this whole code, even to a Sampler class with a sampler.sample(input_ids, next_token_logits) method, which could also include a generic function as proposed. What are your thoughts on this @yjernite @thomwolf @sshleifer @LysandreJik?
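A rough sketch of what such a Sampler class might look like, using the revised (input_ids, next_probs, next_token) -> (next_probs) signature from the first comment (all names hypothetical; this is not an existing transformers API):

```python
import torch

class Sampler:
    """Hypothetical home for the built-in warps plus user-supplied transforms."""

    def __init__(self, temperature=1.0, transforms=None):
        self.temperature = temperature
        # each transform: (input_ids, next_probs, next_token) -> next_probs
        self.transforms = transforms or []

    def sample(self, input_ids, next_token_logits):
        probs = torch.softmax(next_token_logits / self.temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
        for transform in self.transforms:
            # a transform may re-warp the distribution, in which case we
            # draw again from the modified probabilities
            probs = transform(input_ids, probs, next_token)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
        return next_token
```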