Top-p sampling samples an extra token when the cumulative sum of probabilities is exactly equal to top_p
System Info
- transformers version: 4.20.1
- Platform: Linux-5.10.133+-x86_64-with-debian-bullseye-sid
- Python version: 3.7.12
- Huggingface_hub version: 0.8.1
- PyTorch version (GPU?): 1.11.0+cpu (False)
- Tensorflow version (GPU?): 2.6.4 (False)
- Flax version (CPU?/GPU?/TPU?): 0.6.0 (cpu)
- Jax version: 0.3.16
- JaxLib version: 0.3.15
Who can help?
@patrickvonplaten @Narsil @gante
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
Top-p sampling samples an extra token when the cumulative sum of the token probabilities is exactly equal to the given top_p. E.g., if the input probabilities are [0.3, 0.1, 0.1, 0.5] and top_p = 0.8, then only the 2 tokens with probabilities 0.5 and 0.3 should be kept for sampling, since their sum is exactly equal to 0.8. I believe this is the expected behavior of top-p sampling according to the definition, which states:

top_p (float, optional, defaults to 1.0) — If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
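For clarity, here is a simplified sketch of the filtering logic as I understand it from the current TopPLogitsWarper (softmax omitted since the example probabilities already sum to 1, and min_tokens_to_keep ignored for brevity):

```python
import torch

probs = torch.tensor([0.3, 0.1, 0.1, 0.5])
top_p = 0.8

# Current logic: sort descending, flag tokens once the cumulative
# probability exceeds top_p, then shift the mask right so the first
# token crossing the threshold is also kept.
sorted_probs, _ = torch.sort(probs, descending=True)
cumulative_probs = sorted_probs.cumsum(dim=-1)        # [0.5, 0.8, 0.9, 1.0]
sorted_indices_to_remove = cumulative_probs > top_p   # [F, F, T, T]
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
sorted_indices_to_remove[0] = False                   # [F, F, F, T]

print((~sorted_indices_to_remove).sum().item())
# 3 tokens survive (0.5, 0.3 and one of the 0.1s), even though
# 0.5 + 0.3 already adds up to exactly top_p = 0.8, so only 2
# should be kept.
```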
I have created a notebook which reproduces this behavior: https://www.kaggle.com/ekagra/hf-contrib-topp. The notebook also contains a proposed implementation which fixes this, with the added optimization of not needing to clone a tensor and shift it to the left or right.
I have checked locally that the proposed implementation passes the existing unit tests.
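One way to achieve this (a sketch of the idea, not necessarily identical to the notebook's code) is to sort ascending, so the removal mask can be computed directly without any clone() or shift. The example values here are exact binary fractions so that the boundary comparison is not disturbed by floating-point rounding:

```python
import torch

# Dyadic probabilities (exactly representable in binary floating
# point), so the boundary comparison below is not affected by rounding.
probs = torch.tensor([0.25, 0.125, 0.125, 0.5])
top_p = 0.75

# Sort ascending: tokens whose cumulative probability is at most
# 1 - top_p can be dropped directly, with no clone() and no shift.
sorted_probs, _ = torch.sort(probs, descending=False)
cumulative_probs = sorted_probs.cumsum(dim=-1)              # [0.125, 0.25, 0.5, 1.0]
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)  # [T, T, F, F]

print((~sorted_indices_to_remove).sum().item())
# 2 tokens survive (0.5 and 0.25); their sum is exactly top_p = 0.75,
# matching the documented "add up to top_p or higher" behavior.
```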
Your contribution
If this makes sense, I would be happy to raise a PR for this.
Issue Analytics
- State:
- Created a year ago
- Comments:5 (5 by maintainers)
Top GitHub Comments
@ekagra-ranjan that is fine, as long as you also edit the tests for Flax and TF (as in my PR), to ensure the three frameworks have the same behavior.
@gante Actually, I wanted to raise a PR with my implementation because it has an optimization of not requiring cloning an intermediate tensor and shifting things to the right (as done in the current implementation). I have raised the PR. Could you please review it?