Potential incorrect loss calculation for TFTokenClassification in TFTrainer
Environment info
- transformers version: 3.1.0
- Platform: Linux-4.15.0-115-generic-x86_64-with-debian-buster-sid
- Python version: 3.6.7
- PyTorch version (GPU?): 1.5.1+cpu (False)
- Tensorflow version (GPU?): 2.2.0 (False)
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
- Trainer: @sgugger
- tensorflow: @jplu
- examples/token-classification: @stefan-it
Mostly for @jplu, potentially for @stefan-it (because the workaround I have in mind requires a small change in the token classification dataset).
Information
The problem arises when using:
- The official example scripts. The involved scripts are:
  - https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py
  - https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_tf_ner.py
However, to demonstrate the issue more clearly, I use a minimal example that does not use these two scripts directly. See the description and code snippet below.
The task I am working on is:
- The official token classification task in TensorFlow
Description
In trainer_tf.py, the loss is computed from per_example_loss divided by total_train_batch_size:

```python
per_example_loss, _ = self.run_model(features, labels, True)
scaled_loss = per_example_loss / self.total_train_batch_size
```

Here, total_train_batch_size is the size of the whole batch that will be distributed to (potentially) several replicas and optionally split into smaller batches for gradient accumulation steps.
For sentence-level tasks, where each example (i.e. sentence) corresponds to one label (for example, sentence classification), the above loss calculation is correct.
However, for token-level tasks like token classification, the above loss seems incorrect to me. For such tasks, the loss should be the per-example losses divided by the number of real tokens involved in the batch.
In utils_ner, convert_examples_to_features sets the labels to -100 for padding tokens and other special tokens ([CLS], [SEP], etc.), which are the positions to be ignored in the loss calculation. Therefore, the loss should be the per-example losses divided by the number of labels that are not -100 in the *batch*.
By *batch*, I mean neither the batch received by a single replica nor the smaller batch in a single accumulation step, but the whole batch that will be distributed to (potentially) different replicas and optionally split into several smaller batches for accumulation steps.
More precisely, it is a batch passed to distributed_training_steps() - for the same reason that we divide the per-example losses by total_train_batch_size for sentence-level tasks rather than by the size of the batch received by a single replica.
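As a rough sketch of the scaling I have in mind (this is not the actual trainer_tf.py code, and it assumes access to the labels of the whole global batch):

```python
import tensorflow as tf

# Count the label positions that actually contribute to the loss,
# i.e. those whose label is not the ignore index -100,
# over the whole global batch.
num_active_labels = tf.reduce_sum(
    tf.cast(tf.not_equal(labels, -100), per_example_loss.dtype)
)

# Token-level scaling: divide by the number of real (non-ignored) tokens
# instead of by total_train_batch_size.
scaled_loss = per_example_loss / num_active_labels
```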
In order to calculate the correct loss values, we have to pass the global information - the number of labels that are not -100 in the global batch - to each replica. I don't know a clean way to do it, but for my own personal projects, I inject this extra information into the global batches as a constant, so each replica receiving a distributed smaller batch has what it needs to calculate the correct scaled losses.
(I have a notebook showing how to do this; if you want to take a look, let me know.)
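For reference, a minimal sketch of that injection could look like the following; the feature name num_active_labels and the exact plumbing are my own choices, not part of transformers:

```python
import tensorflow as tf

def add_global_label_count(features, labels):
    # Number of label positions that are not the ignore index -100,
    # counted over this whole (global) batch.
    num_active = tf.reduce_sum(tf.cast(tf.not_equal(labels, -100), tf.float32))
    # Broadcast the count along the batch dimension so that it survives
    # the per-replica splitting done by the distribution strategy.
    features = dict(features)
    features['num_active_labels'] = tf.fill(tf.shape(labels)[:1], num_active)
    return features, labels

# Applied to the dataset after batching to the global batch size and
# before it is distributed, e.g.:
# train_dataset = train_dataset.batch(total_train_batch_size).map(add_global_label_count)
```

Inside the training step, each replica can then read the (identical) value from its slice of features['num_active_labels'] and divide its per-example losses by it instead of by total_train_batch_size.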
Code Snippets
Here is a minimal example to demonstrate the issue. We have only one real example (sentence) and n_empty_string empty sentences. Each empty sentence yields only [CLS], [SEP] and [PAD] tokens, which are ignored for token classification.
```python
import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'

SEED = 42
name = 'distilbert-base-uncased'
seq_len = 8
num_labels = 2
n_empty_string = 10

import tensorflow as tf
tf.random.set_seed(SEED)
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")

from transformers import TFTrainer, AutoConfig, AutoTokenizer, TFAutoModelForTokenClassification
from transformers.training_args_tf import TFTrainingArguments

text = [
    'My dog is cute'
]
text.extend([''] * n_empty_string)
n_examples = len(text)

config = AutoConfig.from_pretrained(
    name,
    num_labels=num_labels
)
tokenizer = AutoTokenizer.from_pretrained(name)
model = TFAutoModelForTokenClassification.from_pretrained(
    name
)

training_args = TFTrainingArguments(
    output_dir='./tmp/',
    per_device_train_batch_size=n_examples,
    gradient_accumulation_steps=1,
    seed=SEED
)

# Initialize our Trainer
trainer = TFTrainer(
    model=model,
    args=training_args,
    train_dataset=None,
    eval_dataset=None,
    compute_metrics=None
)
trainer.total_train_batch_size = strategy.num_replicas_in_sync \
    * training_args.per_device_train_batch_size \
    * training_args.gradient_accumulation_steps
trainer.train_loss = tf.keras.metrics.Sum()

features = tokenizer.batch_encode_plus(text, max_length=seq_len, padding='max_length', return_tensors='tf')

# Set all labels to `1`, except for special tokens: cls/sep/pad, where the labels are `-100`.
labels = tf.constant(1, shape=[n_examples, seq_len])
for token_id in [tokenizer.pad_token_id] + tokenizer.all_special_ids:
    labels = labels * tf.cast(features['input_ids'] != token_id, dtype=tf.int32) + \
        -100 * tf.cast(features['input_ids'] == token_id, dtype=tf.int32)

# Only the first example `features[0]` has real tokens, the other examples have only [PAD].
print(features['input_ids'])
# Only the first example has labels that won't be ignored.
print(labels)

# Copy from:
# https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py#L601
per_example_loss, _ = trainer.run_model(features, labels, True)
scaled_loss = per_example_loss / trainer.total_train_batch_size
print(scaled_loss)
```
Expected behavior
When n_empty_string = 0, we get scaled_loss

```
tf.Tensor([0.56047076 0.46507886 0.51456743 0.50131255], shape=(4,), dtype=float32)
```

When n_empty_string = 9, we get scaled_loss

```
tf.Tensor([0.05604707 0.04650789 0.05145674 0.05013125], shape=(4,), dtype=float32)
```

However, in both cases we should get the same values, namely

```
tf.Tensor([0.56047076 0.46507886 0.51456743 0.50131255], shape=(4,), dtype=float32)
```
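Note that the second output is exactly the first divided by 10 (e.g. 0.56047076 / 10 ≈ 0.05604707): with n_empty_string = 9 the global batch contains 10 examples, so total_train_batch_size is 10, while the number of real (non-ignored) tokens is still 4. The padding-only sentences therefore dilute the reported loss even though they contribute nothing to it.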
Comments: 13 (6 by maintainers)
OK, now with an example and the explanation I got it. Thank you very much!
I'd prefer you do a PR, and then you get the credit for this fix 😃 And if you can tag me as reviewer, I will be able to help you if needed, as there is certainly a nicer way to do it. Maybe with a class field?
Thanks again, waiting for your PR ^^
No problem! Take the time you need and let me know.