
[Bug] BCELoss should not be masked


I have trained Tacotron2, but during evaluation/inference it often doesn't know when to stop decoding. This is a known issue in seq2seq models, and I was trying to solve it in TensorFlowTTS when I gave up due to TensorFlow problems.

Training with enable_bos_eos=True helps a bit, but the output is still 3x the ground-truth mel length for shorter audio (see length_data_eos.csv vs. length_data_no_eos.csv).

One reason is the BCELossMasked criterion: in its current form, it encourages the model never to stop decoding once it has passed mel_length. Some of the loss values don't make sense, as seen below:

import torch
from torch.nn import functional


def sequence_mask(sequence_length, max_len=None):
    # Boolean mask of shape B x T_max: True for frames inside each sequence
    if max_len is None:
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
    return mask

length = torch.tensor([95])
mask = sequence_mask(length, 100)
pos_weight = torch.tensor([5.0])
target = 1. - sequence_mask(length - 1, 100).float()  # [0, 0, .... 1, 1] where the first 1 is the last mel frame
true_x = target * 200 - 100  # creates logits of [-100, -100, ... 100, 100] corresponding to target
zero_x = torch.zeros(target.shape) - 100.  # simulate logits if it never stops decoding
early_x = -200. * sequence_mask(length - 3, 100).float() + 100.  # simulate logits that stop 2 frames early
late_x = -200. * sequence_mask(length + 1, 100).float() + 100.  # simulate logits that stop 2 frames late

# if we mask
>>> functional.binary_cross_entropy_with_logits(mask * true_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(3.4657)  # Should be zero! The 5 padded frames become logit 0 / target 0, each adding log(2) ≈ 0.6931
>>> functional.binary_cross_entropy_with_logits(mask * zero_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)
>>> functional.binary_cross_entropy_with_logits(mask * late_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(503.4657)  # Stopping late should score better than never stopping, but the mask hides every frame from mel_length onwards
>>> functional.binary_cross_entropy_with_logits(mask * early_x, mask * target, pos_weight=pos_weight, reduction='sum')
tensor(203.4657)  # Early stopping should score worse than late stopping, because the audio gets cut off

# if we don't mask
>>> functional.binary_cross_entropy_with_logits(true_x, target, pos_weight=pos_weight, reduction='sum')
tensor(0.)  # correct
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=pos_weight, reduction='sum')
tensor(3000.)  # correct
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=pos_weight, reduction='sum')
tensor(1000.)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=pos_weight, reduction='sum')
tensor(200.)  # still wrong: early stopping is penalized less than late stopping (1000)

# pos_weight should be < 1 so that early stopping is penalized more than late stopping
>>> functional.binary_cross_entropy_with_logits(zero_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(120.0000)
>>> functional.binary_cross_entropy_with_logits(late_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(40.0000)
>>> functional.binary_cross_entropy_with_logits(early_x, target, pos_weight=torch.tensor([0.2]), reduction='sum')
tensor(200.)  # correct
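To see where the masked numbers come from, the same computation can be redone frame by frame in plain Python, without PyTorch. This is a sketch under the issue's setup (one utterance of length 95 padded to 100 frames, logits saturated at ±100, pos_weight=5.0). The helpers `softplus`, `bce_with_logits`, `frames` and `masked_loss` are hypothetical names written for illustration, and the per-element formula pos_weight * y * softplus(-x) + (1 - y) * softplus(x) is assumed to be what binary_cross_entropy_with_logits computes.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(x: float, y: float, pos_weight: float) -> float:
    """Per-element BCE-with-logits loss with a weight on the positive class."""
    return pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)


def frames(stop_at: int, total: int = 100, m: float = 100.0) -> list:
    """Saturated logits: -m before frame `stop_at`, +m from `stop_at` onwards."""
    return [-m if t < stop_at else m for t in range(total)]


LENGTH, TOTAL, POS_WEIGHT = 95, 100, 5.0
mask = [1.0 if t < LENGTH else 0.0 for t in range(TOTAL)]
# Target is 1 from the last mel frame (index LENGTH - 1) onwards, as in the issue
target = [0.0 if t < LENGTH - 1 else 1.0 for t in range(TOTAL)]

logits = {
    "true": frames(LENGTH - 1),   # stops exactly on the last mel frame
    "never": [-100.0] * TOTAL,    # never stops decoding
    "late": frames(LENGTH + 1),   # stops 2 frames late
    "early": frames(LENGTH - 3),  # stops 2 frames early
}


def masked_loss(x: list) -> float:
    # Mirrors the issue: logits and targets are multiplied by the mask, so padded
    # frames become (logit 0, target 0) and still add log(2) each to the sum.
    return sum(bce_with_logits(mi * xi, mi * yi, POS_WEIGHT)
               for xi, yi, mi in zip(x, target, mask))


for name, x in logits.items():
    print(f"{name:>5}: {masked_loss(x):.4f}")
```

It prints 3.4657, 503.4657, 503.4657 and 203.4657 for the true, never, late and early cases, matching the masked tensors above. Each padded frame contributes log(2) ≈ 0.6931 (5 × 0.6931 = 3.4657), and since every frame from mel_length onwards is masked, a model that stops late is scored exactly like one that never stops.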

For now I am passing length=None to avoid the mask and setting pos_weight=0.2 to experiment. I will update with the training results.
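The workaround amounts to an unmasked stop loss with a small pos_weight. The sketch below is an illustrative plain-Python version, not the actual BCELossMasked code in TTS; `stop_loss` and `saturated_logits` are hypothetical names, and the logits are the same saturated ±100 values used in the issue. It also makes the role of pos_weight explicit: with saturated logits of magnitude M, each frame that fires too early costs about M, while each frame that keeps decoding past the stop frame costs about pos_weight * M, so early stopping is penalized more than late stopping only when pos_weight < 1.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def stop_loss(logits, target, pos_weight):
    """Unmasked, summed BCE-with-logits over every frame, padding included."""
    return sum(pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
               for x, y in zip(logits, target))


def saturated_logits(stop_at, total=100, m=100.0):
    """Logits of -m before frame `stop_at` and +m from `stop_at` onwards."""
    return [-m if t < stop_at else m for t in range(total)]


LENGTH, TOTAL = 95, 100
target = [0.0 if t < LENGTH - 1 else 1.0 for t in range(TOTAL)]
cases = {
    "never": [-100.0] * TOTAL,
    "late": saturated_logits(LENGTH + 1),   # 2 frames late
    "early": saturated_logits(LENGTH - 3),  # 2 frames early
}

for pos_weight in (5.0, 0.2):
    losses = {name: stop_loss(x, target, pos_weight) for name, x in cases.items()}
    ranking = sorted(losses, key=losses.get)  # best (lowest loss) first
    print(pos_weight, {k: round(v, 4) for k, v in losses.items()}, ranking)
```

With pos_weight=5.0 the never/late/early losses are 3000, 1000 and 200, so early stopping scores best; with pos_weight=0.2 they are 120, 40 and 200, so early stopping becomes the worst case while late stopping still beats never stopping.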

Additional context

I would also propose renaming stop_tokens to either stop_probs or stop_logits depending on context. Currently, inference() produces stop_tokens that represent stop probabilities, while forward() produces the logits before sigmoid. Confusingly, both are called stop_tokens.
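The proposed rename rests on the fact that a stop probability is just the sigmoid of a stop logit. A minimal sketch, assuming a hypothetical 0.5 threshold for deciding when to stop (the actual TTS decoder may use a different rule or value); `stop_probs_from_logits` and `first_stop_frame` are illustrative names, not TTS functions.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def stop_probs_from_logits(stop_logits):
    """Convert pre-sigmoid stop logits (what forward() returns) to probabilities."""
    return [sigmoid(x) for x in stop_logits]


def first_stop_frame(stop_probs, threshold=0.5):
    """Index of the first frame whose stop probability exceeds `threshold`, else None."""
    for t, p in enumerate(stop_probs):
        if p > threshold:
            return t
    return None


stop_logits = [-100.0, -3.0, 0.0, 2.5, 100.0]
stop_probs = stop_probs_from_logits(stop_logits)
print(first_stop_frame(stop_probs))
```

Here `first_stop_frame(stop_probs)` returns 3, the first frame whose probability exceeds 0.5 (a logit of exactly 0 maps to 0.5 and does not trigger a stop). Naming the two lists stop_logits and stop_probs makes it obvious which one may be compared against a probability threshold.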

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 19 (10 by maintainers)

Top GitHub Comments

2 reactions
Edresson commented, Feb 23, 2022

Follow these steps:

1st Fork the 🐸 TTS repository (use the “Fork” button at the top of the page).

2nd Clone your fork (dev branch). The command will be something like: git clone https://github.com/iamanigeeit/TTS.git -b dev

3rd Change the files you need.

4th Commit the changes with the following commands (note: change the commit message 😃):

git add .
git commit -m "Commit message"

5th Push the commits to your fork with: git push

6th Go to your fork (https://github.com/iamanigeeit/TTS). GitHub will detect your changes and show a pull request button below the “Go to file”, “Add file” and “Code” buttons. Click it to open a pull request from your dev branch to Coqui's dev branch 😃.

1 reaction
iamanigeeit commented, Feb 11, 2022

I used the same config as recipes/ljspeech/tacotron2-DCA/train_tacotron_dca.py, except batch_size=32 (due to GPU memory limits) and r=1 (I think r=1 is the correct setting for Tacotron2). Each run was trained for 100k steps.
