Tensorflow Mixed Precision Training
Environment info
- `transformers` version: 4.5.1
- Platform: Linux-5.4.0-74-generic-x86_64-with-glibc2.27
- Python version: 3.8.8
- PyTorch version (GPU?): 1.8.1+cu111 (True)
- Tensorflow version (GPU?): 2.6.0-dev20210604 (True)
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help
@LysandreJik @Rocketknight1 @sgugger
Information
Model I am using (Bert, XLNet …): Bert
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
The task I am working on is:
- an official GLUE/SQuAD task: (give the name)
- my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
- Use TF directly with `model.fit` or `TFTrainer` with the `mixed_float16` policy for mixed precision training.
- Due to a TensorFlow cast issue in the `SparseCategoricalCrossentropy` loss used in many of the Hugging Face TF models, labels can be encoded incorrectly, resulting in `nan` or errors in the loss.
- Errors can start with token (or class) indexes of 2k+, and the loss becomes `nan` with labels closer to the max.
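The 2k+ threshold is consistent with integer labels being cast to float16 somewhere in the loss: float16 has an 11-bit significand, so it represents every integer exactly only up to 2048, and above that neighbouring labels collapse onto the same value. The NumPy sketch below is an illustration of that rounding, not the Keras code path itself; the helper `roundtrip` is hypothetical, and the BERT-base vocabulary size of 30522 is used only as an example. It casts labels to float16 and back, and compares with float32, which is exact for integers up to 2**24.

```python
import numpy as np


def roundtrip(labels, dtype):
    """Cast integer labels to a float dtype and back to int64 (hypothetical helper).

    This mimics what would happen if a loss cast integer labels to the logits' dtype.
    """
    labels = np.asarray(labels, dtype=np.int64)
    return labels.astype(dtype).astype(np.int64)


vocab_size = 30522  # BERT-base vocabulary size, used here only as an example
labels = np.array([5, 2048, 2049, 2051, vocab_size - 1])

fp16 = roundtrip(labels, np.float16)
fp32 = roundtrip(labels, np.float32)

print(fp16.tolist())  # [5, 2048, 2048, 2052, 30528]
print(fp32.tolist())  # [5, 2048, 2049, 2051, 30521]

# In float16 the last label is rounded past the end of the vocabulary,
# so it no longer indexes a valid class.
print(bool(fp16[-1] >= vocab_size))  # True
```

Labels near the top of a large vocabulary are rounded in steps of 16 in float16, which fits the report that errors and `nan` appear for labels closer to the max.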
Expected behavior
Correct loss and no `nan`.
Changing `compute_loss` to use `CategoricalCrossentropy` instead of the sparse loss, and one-hot encoding the labels manually, solves this:
```python
import tensorflow as tf
from transformers.modeling_tf_utils import shape_list


def compute_loss(labels, logits):
    loss_fn = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    # make sure only labels that are not equal to -100 affect the loss
    active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
    reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
    labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
    # changed: one-hot encode the integer labels instead of passing them to a sparse loss
    one_hot_labels = tf.one_hot(labels, tf.shape(logits)[-1], dtype=logits.dtype)
    return loss_fn(one_hot_labels, reduced_logits)
```
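A plausible reason the one-hot workaround helps: the integer label is only used as an index to place a single 1.0 in the vector, and the float16 vector itself holds only 0.0 and 1.0, both of which float16 represents exactly. The NumPy sketch below uses a hypothetical helper `one_hot` as a stand-in for `tf.one_hot`, and shows that the label is recovered exactly even though the vector is float16.

```python
import numpy as np


def one_hot(label, depth, dtype=np.float16):
    """Build a one-hot vector (hypothetical stand-in for tf.one_hot).

    The label is used as an integer index and is never cast to a float dtype.
    """
    vec = np.zeros(depth, dtype=dtype)
    vec[label] = 1
    return vec


vocab_size = 30522  # example vocabulary size
for label in [2049, 2051, vocab_size - 1]:
    vec = one_hot(label, vocab_size)
    # argmax recovers the original integer label exactly
    print(label, int(np.argmax(vec)))
```

Compare this with casting the label itself to float16, where 2049 would already become 2048.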
Changing the last output layer to be float32 also solves this:
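```python
import tensorflow as tf
from transformers import BertConfig
from transformers.models.bert.modeling_tf_bert import TFBertLMPredictionHead


class TFBertMLMHead(tf.keras.layers.Layer):
    def __init__(self, config: BertConfig, input_embeddings: tf.keras.layers.Layer, **kwargs):
        super().__init__(**kwargs)
        self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")
        # changed: a linear activation with dtype float32 casts the logits to float32
        self.finalCast = tf.keras.layers.Activation("linear", dtype="float32")

    def call(self, sequence_output: tf.Tensor) -> tf.Tensor:
        prediction_scores = self.predictions(hidden_states=sequence_output)
        # changed: cast the prediction scores to float32 before they reach the loss
        prediction_scores = self.finalCast(prediction_scores)
        return prediction_scores
```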
But given the recommendation that outputs be accumulated in float32 for numerical stability, perhaps `transform_act_fn` and everything after it needs to be float32?
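The float32 accumulation recommendation can be illustrated without any model: a float16 running sum stops growing once the next term is smaller than half the spacing between representable values. This NumPy sketch is illustrative only; `running_sum` is a hypothetical helper that adds 1.0 three thousand times in each precision.

```python
import numpy as np


def running_sum(n_terms, dtype):
    """Add 1.0 n_terms times, keeping the accumulator in `dtype` (hypothetical helper)."""
    acc = dtype(0)
    one = dtype(1)
    for _ in range(n_terms):
        acc = dtype(acc + one)  # result is rounded back to `dtype` after every step
    return float(acc)


print(running_sum(3000, np.float16))  # 2048.0: stalls, since 2048 + 1 rounds back to 2048
print(running_sum(3000, np.float32))  # 3000.0
```

Reductions over a large vocabulary, such as the normalizing sum inside a softmax, are exposed to the same kind of precision loss, which is one argument for keeping the final logits in float32 as the `finalCast` change does.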
Issue Analytics
- Created 2 years ago
- Comments: 5 (3 by maintainers)
Top GitHub Comments
In PyTorch, we always compute the softmax in FP32 as it’s better for numerical stability. So yes, if possible, we should do the same on the TF side.
This is a fascinating bug in Keras. It’s a known issue that softmaxes can be unstable in float16 or bfloat16, but I didn’t realize that this issue could also smear the labels around too. Tagging #12332 as well, which is a relevant PR. (And maybe this might finally explain my confusion with what was going on in that case!)
I think you’re right, though, that computing the logits in float32 across our models would still improve numerical stability, so it would be worth making that change even if the upstream Keras label cast bug gets fixed. @sgugger wdyt?