
TFWav2Vec2ForCTC & Wav2Vec2ForCTC gives different loss values

See original GitHub issue

Environment info

  • transformers version: 4.8.0.dev0
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.8.1 (False)
  • Tensorflow version (GPU?): 2.5.0 (False)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@will-rice @patrickvonplaten

Information

Model I am using: TFWav2Vec2ForCTC & Wav2Vec2ForCTC

To reproduce

Steps to reproduce the behavior:

import tensorflow as tf
import torch
from transformers import Wav2Vec2ForCTC, TFWav2Vec2ForCTC

# Load the same checkpoint in both frameworks
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
tf_model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Identical labels and raw speech inputs for both models
tf_labels = tf.constant([[3, 54, 65, 76, 21], [32, 42, 434, 76, 231]])
labels = torch.from_numpy(tf_labels.numpy())

tf_speech = tf.random.uniform(shape=(2, 40000))
speech = torch.from_numpy(tf_speech.numpy()).float()

# Forward pass with labels so each implementation computes its CTC loss
with torch.no_grad():
    out = model(speech, labels=labels)
tf_out = tf_model(tf_speech, labels=tf_labels)

print(out["loss"], tf_out["loss"])
# -> 71.64 (PyTorch)    16.92 (TensorFlow)

Expected behavior

Loss values from the TensorFlow and PyTorch models should be similar (note: the logits are exactly the same, as expected).
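
For reference, the expectation could be made concrete with a parity check along these lines (a minimal sketch reusing out and tf_out from the snippet above; the tolerance is an illustrative choice, not from the original report):

import numpy as np

# The two losses should agree within a small numerical tolerance
pt_loss = float(out["loss"])
tf_loss = float(np.squeeze(tf_out["loss"].numpy()))
np.testing.assert_allclose(pt_loss, tf_loss, rtol=1e-3)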

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

3 reactions
thevasudevgupta commented, Jun 20, 2021

@will-rice @patrickvonplaten

PyTorch & TensorFlow losses are becoming different if padding indices are set to -100. Checkout this small Colab notebook.

This is happening because these 2 lines [1584 & 1585] should not be there in TensorFlow implementation. If we just remove them, PyTorch & TensorFlow loss will become same.

So:

# we should remove these lines
flattened_labels = tf.boolean_mask(labels, labels_mask)            
flattened_labels = tf.reshape(flattened_labels, [labels.shape[0], -1])

# rather replace it with
flattened_labels = labels
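
For context, here is a minimal sketch of the scenario described above, reusing the models and inputs from the reproduction snippet (the label values are hypothetical, chosen to stay inside the vocabulary, with -100 used as padding):

# Hypothetical labels inside the vocabulary; the second sequence is padded
# to equal length with -100, the trigger mentioned above.
tf_padded_labels = tf.constant([[3, 14, 5, 7, 21], [12, 2, 24, -100, -100]])
padded_labels = torch.from_numpy(tf_padded_labels.numpy())

with torch.no_grad():
    pt_loss = model(speech, labels=padded_labels)["loss"]
tf_loss = tf_model(tf_speech, labels=tf_padded_labels)["loss"]

# With the boolean_mask/reshape lines in place the two values diverge;
# with `flattened_labels = labels` they agree.
print(pt_loss, tf_loss)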
2 reactions
patrickvonplaten commented, Jun 20, 2021

It should actually throw an error if labels are > vocab_size! Will open an issue for this.
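
A minimal sketch of the kind of guard being suggested (the function name, message, and placement are assumptions, not the actual implementation):

import tensorflow as tf

def validate_ctc_labels(labels: tf.Tensor, vocab_size: int) -> None:
    # Hypothetical check: every label must index into the vocabulary
    max_label = int(tf.reduce_max(labels))
    if max_label >= vocab_size:
        raise ValueError(
            f"Label values must be < vocab_size ({vocab_size}), but found {max_label}."
        )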
