[Feature Contribution] Disjunctive Positive Constraint Decoding (adding `force_tokens` to `model.generate()`)
Feature request
The "Disjunctive Positive Constraint Decoding" algorithm was proposed by a recent paper: Guided Generation of Cause and Effect.
Currently the model.generate() method limits the user to just the highest-probability outputs. Here the user is able to force diverse outputs by forcing the model to include diverse tokens across multiple generations.
This method is called "Disjunctive Positive Constraint Decoding", and it forces the model.generate() process to generate sequences with the highest probabilities under the constraint of needing to include a set of provided tokens.
This "disjunctive" method is powerful in that it can handle lemmatized forms of these forced tokens. For instance, when asking the model to autoregressively complete "Babies cry because" while forcing the generation to include the word "lonely", it can induce the model to generate sequences like "Babies cry because they are lonely" as well as "Babies cry because of their loneliness".
I think that this could be implemented as:
model.generate(force_tokens=[["lonely", "loneliness", ...], ["happy", "happiness", ...]])
where the input to force_tokens is a 2D array in which each inner list contains the different forms of one desired token.
Otherwise it could be:
model.generate(force_tokens=["lonely", "happy"])
but in this case the transformers library would need a built-in lemmatization engine, which I think is better left for the practitioner to figure out (i.e. go with the first method instead).
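To make the first (2D-array) form concrete, here is a minimal sketch of how a user might call it. The `force_tokens` argument is only the proposal above and does not exist in the library; the tokenizer usage is standard.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Each inner list holds surface forms of one constraint; the generated
# sequence would have to contain at least one form from every inner list.
force_tokens = [
    ["lonely", "loneliness"],
    ["happy", "happiness"],
]

inputs = tokenizer("Babies cry because", return_tensors="pt")

# Hypothetical call using the proposed `force_tokens` argument
# (not a real transformers argument at the time of writing):
outputs = model.generate(**inputs, force_tokens=force_tokens, num_beams=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# If the eventual implementation expects token ids rather than strings,
# the same constraints could be converted with the tokenizer:
force_token_ids = [
    [tokenizer(form, add_special_tokens=False).input_ids for form in group]
    for group in force_tokens
]
```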
Motivation
Diversity in the outputs of sequence generation from LMs has always been an active problem. Though diversity-inducing methods have usually involved a dual implementation with a VAE or a modified training scheme, this feature would allow the practitioner to induce diversity in a very controllable way.
A large pretrained language model probably has the capacity to generate all kinds of different expressions given an input, but the generation usually gets limited to the most probable outputs. One obvious solution is to use sampling instead, but that brings back the problem of controllability. This level of control is extremely useful in model implementations that aim to learn syntactic transformations that need to preserve certain entities, or in QA verbalizers where we have pre-existing knowledge of what the answer should be.
Instead of making the model generate many sequences and filtering for the desired ones, this would allow us to force it to generate an output that we want, which a large LM can probably do well; even if it can't figure out a way, we can simply filter out the low-probability outputs based on a threshold.
Your contribution
I'm happy to submit a PR for the full implementation if there aren't any objections to this feature.
But I do believe I should get some feedback on this idea before proceeding with an implementation, since it's not exactly clear what the best way to introduce this functionality to the library would be.
Thanks for reviewing this thread @patrickvonplaten @Narsil.
Though it'd be ideal if this could be solved with a simple custom `LogitsProcessor`, it seems like this problem requires at least a dedicated beam search function (`LexicallyConstrainedBeamSearch` / `DisjunctivelyConstrainedBeamSearch`). Upon further research I realized that similar features already exist in Fairseq and Sockeye.
The Fairseq implementation is introduced in this readme and the code is here (LexicallyConstrainedBeamSearch). Sockeye's implementation is here.
These implementations are mainly based on the following papers:
- Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation
- Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting
I feel that the existence of implementations in other major text generation libraries hints at the importance of such a function, and similar ideas about "whitelisting" certain tokens (as opposed to blacklisting with bad_words_ids) have been discussed before in the Hugging Face forums.
I think this is an even worthier cause since the "Disjunctive Positive Constraint Decoding" approach I proposed goes one step beyond all of the above implementations in that it can handle lemmatized constraints.
For example, users of Sockeye or Fairseq can only force the generation to include the word "rain", which ends up preventing the model from generating the word "raining". This disjunctive approach, on the other hand, can instead say "generate one of {rain, raining}", which makes the function much more intuitive to use.
As one can imagine, implementing even the simple lexically constrained approach found in Fairseq and Sockeye requires a dedicated beam search function, and it's much more complex than boosting / reducing logits at each step with a custom `LogitsProcessor`. I'm wondering if such an approach makes the scale of this implementation too large and complex for merging into master. I'm personally more than willing to write the full implementation, with added confidence since half the work is already done given that other libraries have similar implementations.
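For reference, the per-step boosting baseline mentioned above could look roughly like the sketch below; the class name and bonus value are made up for illustration. It nudges the model toward the constraint tokens but cannot guarantee that any of them actually appears, which is exactly why a dedicated constrained beam search seems necessary.

```python
import torch
from transformers import LogitsProcessor


class BoostTokensLogitsProcessor(LogitsProcessor):
    """Add a constant bonus to the logits of a set of token ids at every step.

    This only makes the tokens more likely; it cannot enforce that one of
    them ends up in the generated sequence.
    """

    def __init__(self, token_ids, bonus=5.0):
        self.token_ids = list(token_ids)
        self.bonus = bonus

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Boost the scores of the chosen token ids for every beam/batch entry.
        scores[:, self.token_ids] = scores[:, self.token_ids] + self.bonus
        return scores
```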
This went under the radar, sorry about that.
I'll let Patrick discuss actual inclusion within transformers, but FYI, we're enabling a generic `logits_processor` argument, which should enable you to arbitrarily reassign logits during generation: https://github.com/huggingface/transformers/pull/12219. If you could frame your implementation as a `LogitsProcessor` type of object, that would make inclusion super simple (and even if it does not get merged, usage should be quite smooth).
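As a usage sketch of that route (assuming a transformers version that includes the PR above, and reusing the hypothetical BoostTokensLogitsProcessor from the earlier sketch), a custom processor can be passed to generate through the logits_processor argument:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Token ids to encourage: here simply the first sub-token of each surface form.
boost_ids = [
    tokenizer(" lonely", add_special_tokens=False).input_ids[0],
    tokenizer(" loneliness", add_special_tokens=False).input_ids[0],
]

inputs = tokenizer("Babies cry because", return_tensors="pt")
outputs = model.generate(
    **inputs,
    logits_processor=LogitsProcessorList([BoostTokensLogitsProcessor(boost_ids)]),
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```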