Trivial/degenerate solution to Wav2vec 2
❓ Questions and Help
Before asking:
- search the issues.
- search the docs.
What is your question?
I was trying to understand Wav2Vec 2.0, and it seems the implementation might lead to trivial solutions in some cases.
Specifically, if the model always assigns the positive and the negative samples the same code, it can achieve a very good InfoNCE estimate.
I believe the reason lies in the implementation of compute_preds() in https://github.com/pytorch/fairseq/blob/master/fairseq/models/wav2vec/wav2vec2.py#L478: whenever neg_is_pos evaluates to True, the corresponding negative logits are set to -inf. In that case, the cross-entropy loss in wav2vec_criterion.py is trivially minimized, even though the learnt representation is not necessarily meaningful.
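As a rough sketch of the concern (my own illustration, not code from fairseq): once every negative logit is masked to -inf, the softmax puts all probability mass on the positive at index 0, so the cross-entropy is exactly zero no matter how poor the positive similarity actually is.

import torch
import torch.nn.functional as F

# Hypothetical illustration: one prediction row after every negative has been
# masked (all sampled negatives coincided with the positive code).
cos_sim = 0.03      # an arbitrary, even poor, positive similarity
logit_temp = 0.1
logits = torch.tensor([[cos_sim / logit_temp, float("-inf"), float("-inf")]])
target = torch.zeros(1, dtype=torch.long)   # the positive is always index 0

loss = F.cross_entropy(logits, target)
print(loss)  # tensor(0.) -- the contrastive loss is minimized regardless of x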
I wonder whether this is by design. Have you encountered this issue before, and do you have any advice on how to avoid it?
Your help is very much appreciated.
Code
def compute_preds(self, x, y, negatives):
    # mark negatives that are identical to the positive target
    neg_is_pos = (y == negatives).all(-1)
    y = y.unsqueeze(0)
    targets = torch.cat([y, negatives], dim=0)

    # cosine similarity between context outputs and (positive + negative) targets
    logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
    logits = logits / self.logit_temp

    if is_xla_tensor(logits) or neg_is_pos.any():
        fillval = -float(2 ** 30)
        if not hasattr(self, "_inftensor"):
            self._inftensor = (
                torch.tensor(fillval).to(x.device)
                if is_xla_tensor(logits)
                else float("-inf")
            )
        # mask out negatives that coincide with the positive
        logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)

    return logits
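For what it's worth, here is a minimal standalone sketch (my own, not fairseq code) of how the masking above behaves when the quantizer has collapsed to a single code, so that every sampled negative equals the positive:

import torch

B, T, D, n_negs = 2, 4, 8, 3
code = torch.randn(D)                     # a single (collapsed) codebook entry
y = code.expand(B, T, D)                  # positives: the same code everywhere
negatives = code.expand(n_negs, B, T, D)  # negatives: also the same code

neg_is_pos = (y == negatives).all(-1)     # shape (n_negs, B, T), all True here
print(neg_is_pos.all())                   # tensor(True)

x = torch.randn(B, T, D)                  # context network outputs
targets = torch.cat([y.unsqueeze(0), negatives], dim=0)
logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
logits[1:][neg_is_pos] = float("-inf")    # every negative logit gets masked

With the positive at index 0 and all other logits at -inf, the cross-entropy term vanishes, which is the degenerate behaviour described above.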
What have you tried?
What’s your environment?
- fairseq Version (master):
- PyTorch Version (1.7)
- OS (e.g., Linux): ubuntu 18
- How you installed fairseq (pip, source): source
- Build command you used (if compiling from source): pip install --editable ./
- Python version: 3.7
- CUDA/cuDNN version: 10.1
- GPU models and configuration: 2080 Ti
- Any other relevant information:
Top GitHub Comments
I am facing the same problem. I am training wav2vec with my own dataset using the default wav2vec2_base.yaml configuration, but after some time the training accuracy drops to zero while the validation accuracy increases to 1. The code perplexity and loss_0 do not look right either.
When I load the best checkpoint and inspect the features wav2vec extracts, the output is full of -inf.
When I extract features on the validation set, they are all identical across time steps.
My environment:
- fairseq Version (master):
- PyTorch Version (1.9)
- OS (e.g., Linux): ubuntu 18
- How you installed fairseq (pip, source): source
- Build command you used (if compiling from source): pip install --editable ./
- Python version: 3.8.10
- CUDA/cuDNN version: 11.2
- GPU models and configuration: Tesla V100
- Any other relevant information: I am training wav2vec with my own dataset, where the average audio length is around 7 minutes.
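In case it helps with debugging the collapse described in the comment above (this is just a sketch of a check one could run, not something from the thread): measure how similar the extracted features are across time steps; off-diagonal values near 1.0 everywhere suggest the representations have collapsed.

import torch
import torch.nn.functional as F

# Hypothetical collapse check: `feats` stands in for the (T, D) features
# extracted from one utterance with the trained checkpoint.
feats = torch.randn(200, 768)             # replace with real extracted features
feats = F.normalize(feats, dim=-1)
sim = feats @ feats.t()                   # (T, T) cosine similarity matrix
mask = ~torch.eye(sim.size(0), dtype=torch.bool)
print(f"mean off-diagonal cosine similarity: {sim[mask].mean():.3f}")
# values close to 1.0 across utterances indicate identical features per time step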
Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!