Mismatch between sentinel token IDs from T5 data collator and T5 tokenizer
Environment info
- transformers version: 4.10.3
- Platform: Linux-4.18.0-240.el8.x86_64-x86_64-with-glibc2.28
- Python version: 3.9.7
- PyTorch version (GPU?): 1.10.0 (True)
- Tensorflow version (GPU?): 2.6.0 (True)
- Flax version (CPU?/GPU?/TPU?): 0.3.5 (gpu)
- Jax version: 0.2.24
- JaxLib version: 0.1.73
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help
@patil-suraj @patrickvonplaten
Information
I’m trying to use the run_t5_mlm_flax.py script to do additional pretraining of T5, and I noticed something strange about the way the data collator adds mask/sentinel tokens. In line 293 of run_t5_mlm_flax.py, the create_sentinel_ids function replaces the masked positions with the corresponding sentinel IDs as sentinel_ids + self.tokenizer.vocab_size - 1, which gives values of 32100, 32101, 32102, and so on. However, the sentinel tokens <extra_id_0>, <extra_id_1>, <extra_id_2>, ... in the tokenizer for t5-base have the token IDs 32099, 32098, 32097, ..., which I’m getting from running the following:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-base')
print(tokenizer.convert_tokens_to_ids(['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_99>']))
# prints:
# [32099, 32098, 32097, 32000]
The larger token IDs seem to work without error because the T5ForConditionalGeneration pretrained model has an extra 128 token embeddings (even though tokenizer.vocab_size gives a value of 32100, which seems to be related to issue #4875), but I’m not sure these are the same embeddings that were used for the sentinel tokens during the original pretraining. Is the script correct in replacing the mask tokens with token IDs starting from 32100, even though they don’t correspond to the <extra_id_#> tokens in the vocabulary?
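As a quick sanity check of that size mismatch: for t5-base the embedding matrix is padded to 32128 rows while the tokenizer only defines 32100 tokens, so IDs 32100–32127 have embeddings but no corresponding token string.
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

print(len(tokenizer))                                # 32100 tokens defined by the tokenizer
print(model.config.vocab_size)                       # 32128 rows in the embedding matrix
print(model.get_input_embeddings().weight.shape[0])  # 32128, so IDs 32100-32127 exist as
                                                     # embeddings but not as tokenizer entries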
Here’s an example of the behavior of create_sentinel_ids alone:
import numpy as np
from transformers import AutoTokenizer

def create_sentinel_ids(mask_indices, tokenizer):
    # Same computation as the data collator's create_sentinel_ids in
    # run_t5_mlm_flax.py, with the tokenizer passed in explicitly.
    start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
    start_indices[:, 0] = mask_indices[:, 0]
    sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
    sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + tokenizer.vocab_size - 1), 0)
    sentinel_ids -= mask_indices - start_indices
    return sentinel_ids

tokenizer = AutoTokenizer.from_pretrained('t5-base')
mask_indices = np.array([[0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1]]).astype(bool)
print(create_sentinel_ids(mask_indices.astype(np.int8), tokenizer))
# prints:
# [[ 0 32100 0 0 32101 -1 0 32102 -1 -1 0 32103]]
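For comparison, here’s a sketch of the same function with the sentinel IDs counting down from the end of the tokenizer’s vocabulary instead (using len(tokenizer) - sentinel_ids; I’m not claiming this is the exact expression the script should use, just illustrating the descending convention). It reuses the tokenizer and mask_indices from the snippet above:
def create_sentinel_ids_descending(mask_indices, tokenizer):
    # Identical to the collator's logic, except the i-th sentinel maps to
    # len(tokenizer) - 1 - i, i.e. the IDs of <extra_id_0>, <extra_id_1>, ...
    start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
    start_indices[:, 0] = mask_indices[:, 0]
    sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
    sentinel_ids = np.where(sentinel_ids != 0, (len(tokenizer) - sentinel_ids), 0)
    sentinel_ids -= mask_indices - start_indices
    return sentinel_ids

print(create_sentinel_ids_descending(mask_indices.astype(np.int8), tokenizer))
# prints:
# [[ 0 32099 0 0 32098 -1 0 32097 -1 -1 0 32096]]
# (the -1 entries mark continuation positions of a masked span and are dropped
# later when the collator filters the input IDs)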
Sounds good! Submitted a pull request (#14477) with the correct expression for the sentinel token IDs; let me know if I should make any other changes.
Hey @patrickvonplaten, the sentinels start from the end of the vocab (i.e. the first sentinel is vocab_size - 1) and descend: https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py#L2893
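A quick check of that convention against the t5-base tokenizer, using len(tokenizer) (32100) as the vocab size:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-base')
for i in (0, 1, 2, 99):
    # <extra_id_i> sits at len(tokenizer) - 1 - i, descending from the end of the vocab
    assert tokenizer.convert_tokens_to_ids(f'<extra_id_{i}>') == len(tokenizer) - 1 - i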