TSDAE layer initialization of encoder and decoder

See original GitHub issue

Hi @kwang2049 and @nreimers

I want to pretrain a sentence transformer using TSDAE. From my understanding I can either use a language model checkpoint (like bert-base-uncased, as it is done in the example script), an already trained sentence-transformer, or train the model completely from randomly initialized weights.
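Concretely, I understand the first two options roughly as follows (the model names below are just examples, not necessarily what we will use):

from sentence_transformers import SentenceTransformer, models

# Option 1: start from a plain language model checkpoint (CLS pooling, as in the example script)
word_emb = models.Transformer("bert-base-uncased")
pooling = models.Pooling(word_emb.get_word_embedding_dimension(), "cls")
model_from_lm = SentenceTransformer(modules=[word_emb, pooling])

# Option 2: start from an already trained sentence-transformer
model_from_st = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Option 3 (random init) would mean building the transformer from a config rather than a checkpoint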

We have previously used all-MiniLM-L6-v2 as a checkpoint and fine-tuned it with MultipleNegativesRankingLoss, with the main downstream task being document clustering (we use summaries of the documents as input to the sentence transformer). Now we want to see whether we can improve performance by first using TSDAE.

Now the problem is that TSDAE does not look like it is initializing properly when following the example training script:

train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
When tie_encoder_decoder=True, the decoder_name_or_path will be invalid.
Some weights of BertLMHeadModel were not initialized from the model checkpoint at /home/soeren/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/ and are newly initialized: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'encoder.layer.3.crossattention.self.query.weight', 'encoder.layer.0.crossattention.self.query.weight', 'encoder.layer.0.crossattention.self.query.bias', 'encoder.layer.0.crossattention.output.LayerNorm.bias', 'encoder.layer.4.crossattention.self.query.bias', 'encoder.layer.1.crossattention.self.query.weight', 'encoder.layer.3.crossattention.output.LayerNorm.weight', 'encoder.layer.3.crossattention.self.query.bias', 'encoder.layer.2.crossattention.output.dense.bias', 'encoder.layer.0.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.self.key.bias', 'encoder.layer.1.crossattention.self.key.weight', 'encoder.layer.2.crossattention.self.query.weight', 'encoder.layer.0.crossattention.self.value.bias', 'encoder.layer.4.crossattention.output.LayerNorm.weight', 'encoder.layer.2.crossattention.self.key.weight', 'encoder.layer.2.crossattention.self.key.bias', 'encoder.layer.3.crossattention.self.key.bias', 'encoder.layer.5.crossattention.output.dense.weight', 'encoder.layer.1.crossattention.output.LayerNorm.bias', 'encoder.layer.3.crossattention.output.dense.bias', 'encoder.layer.1.crossattention.self.value.weight', 'cls.predictions.bias', 'encoder.layer.4.crossattention.self.key.bias', 'encoder.layer.5.crossattention.self.key.weight', 'encoder.layer.4.crossattention.output.LayerNorm.bias', 'encoder.layer.2.crossattention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'encoder.layer.4.crossattention.self.value.bias', 'encoder.layer.3.crossattention.output.dense.weight', 'encoder.layer.5.crossattention.self.query.bias', 'encoder.layer.0.crossattention.self.value.weight', 'encoder.layer.0.crossattention.self.key.weight', 'encoder.layer.2.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.self.query.weight', 'encoder.layer.2.crossattention.self.value.bias', 'encoder.layer.4.crossattention.output.dense.bias', 'encoder.layer.5.crossattention.output.dense.bias', 'encoder.layer.2.crossattention.output.LayerNorm.bias', 'encoder.layer.1.crossattention.self.value.bias', 'encoder.layer.3.crossattention.self.value.bias', 'encoder.layer.4.crossattention.self.value.weight', 'encoder.layer.3.crossattention.self.value.weight', 'encoder.layer.4.crossattention.self.key.weight', 'encoder.layer.1.crossattention.output.dense.bias', 'encoder.layer.5.crossattention.output.LayerNorm.bias', 'encoder.layer.4.crossattention.output.dense.weight', 'encoder.layer.2.crossattention.self.query.bias', 'encoder.layer.5.crossattention.self.value.bias', 'encoder.layer.3.crossattention.output.LayerNorm.bias', 'encoder.layer.1.crossattention.self.key.bias', 'encoder.layer.0.crossattention.output.dense.bias', 'encoder.layer.4.crossattention.self.query.weight', 'encoder.layer.3.crossattention.self.key.weight', 'encoder.layer.0.crossattention.self.key.bias', 'encoder.layer.0.crossattention.output.dense.weight', 'encoder.layer.1.crossattention.self.query.bias', 'encoder.layer.1.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.self.value.weight', 'encoder.layer.2.crossattention.output.dense.weight', 'encoder.layer.1.crossattention.output.dense.weight', 
'cls.predictions.transform.LayerNorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The following encoder weights were not tied to the decoder ['bert/pooler']

First of all, it looks like all of the encoder weights are not properly initialized, and I am unsure how the decoder is initialized when using a sentence-transformer. Also, the last line says something about the pooler not being tied to the decoder, which is probably referring to the pooling layer over the contextualized word embeddings?

TL;DR

Unsure whether I am using TSDAE properly, since I am initializing from a sentence-transformer that does not have a decoder like BERT’s LMHead
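For reference, the rest of my setup follows the documented TSDAE example, roughly like the sketch below (the sentence list, batch size and output path are placeholders, not our exact values):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, datasets, losses

model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_name)

# Plain, unlabeled sentences from our corpus (document summaries)
train_sentences = ["summary one ...", "summary two ...", "summary three ..."]

# Dataset that adds deletion noise on the fly, yielding (noisy, original) sentence pairs
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# The line that triggers the warnings above
train_loss = losses.DenoisingAutoEncoderLoss(
    model, decoder_name_or_path=model_name, tie_encoder_decoder=True
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    scheduler="constantlr",
    optimizer_params={"lr": 3e-5},
    weight_decay=0,
    show_progress_bar=True,
)
model.save("output/tsdae-all-MiniLM-L6-v2")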

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
kwang2049 commented on Feb 3, 2022

Hi @sorenmc,

I am not really sure about your training and inference settings. Here I just comment on the general concern.

Actually, all-MiniLM-L6-v2 was trained on a large amount of labeled data (1B examples) across many datasets and domains, so in many cases it will outperform a model trained only with TSDAE on the downstream task. As shown in the paper, it works much better to first train with TSDAE on the target downstream task/corpus and then train with a supervised loss on the labeled data (from general domains). In other words, TSDAE can be viewed as a pre-training method.

However, this is limited by training efficiency: you need to train again whenever a new task/corpus comes along. In particular, re-training the model on the training data of all-MiniLM-L6-v2 is not always possible due to its huge size.

To pursue optimal performance, I would suggest trying post-training methods like AugSBERT (for sentence-level tasks) or QGen/GPL (for query-document retrieval/IR tasks). These can be built directly on top of the existing supervised training effort and adapt the pre-trained model to the target task/domain.

P.S.: (1) As shown in the paper, TSDAE cannot improve performance when used in a post-training manner. (2) To answer your specific question about training epochs, I found that just a few epochs (around 1-3) are enough.

1 reaction
kwang2049 commented on Jan 23, 2022

Hi @sorenmc,

Really sorry to reply to you this late. I was quite busy during the past week.

Thanks a lot for your attention! Let me explain the warning messages one by one:

  1. “When tie_encoder_decoder=True, the decoder_name_or_path will be invalid.”: This means that by setting tie_encoder_decoder=True, the weights of the encoder and the decoder will be shared (actually copied from the encoder), and thus decoder_name_or_path will be ignored. This weight-tying option reduces the total number of training parameters while keeping the same performance (or even improving it); see the short sketch after this list.
  2. “Some weights of BertLMHeadModel were …”: This message is from Huggingface Transformers. You will actually see this warning every time you use TSDAE, since TSDAE relies on a decoder and also on cross-attention to build a connection between the decoding and the encoding.
  3. “You should probably TRAIN this model on a down-stream task …”: This is also from Huggingface. It gives this warning because, as mentioned in 2., the newly created parameters (e.g. in the cross-attention layers) are randomly initialized. They need proper training, and through TSDAE they will get well trained.
  4. “The following encoder weights were not tied to the decoder [‘bert/pooler’]”: This is also from Huggingface but related to TSDAE usage. Many Transformer models have a pooler layer on top of the CLS token representation, designed as the input to some classification task head (of course this pooler can also be used in sentence embedding tasks). Obviously, the decoder does not have such a pooler layer, so these parameters cannot be shared either.
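
A minimal sketch of the tied vs. untied options (the untied decoder checkpoint is just an example):

from sentence_transformers import losses

# model: a SentenceTransformer instance (e.g. all-MiniLM-L6-v2)

# Tied: the decoder weights are shared with (copied from) the encoder,
# so decoder_name_or_path is ignored - hence the first warning
tied_loss = losses.DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)

# Untied: the decoder gets its own weights, initialized from the given checkpoint
untied_loss = losses.DenoisingAutoEncoderLoss(
    model, decoder_name_or_path="bert-base-uncased", tie_encoder_decoder=False
)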

So in summary, you do not need to worry about these warning messages (especially for your case with all-MiniLM-L6-v2) 😃.

BTW, just to give some advice here in case it is needed (though you seem to have understood the point already): TSDAE works great for domain adaptation/pre-training, so in your case you can try (1) first doing TSDAE on the corpus of your downstream task (i.e. in an unsupervised manner) and (2) then doing MultipleNegativesRankingLoss on your downstream task (i.e. in a supervised manner), as sketched below.
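
A rough sketch of that two-step recipe with the sentence-transformers API (the sentences, pairs, paths and hyperparameters below are placeholders):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, datasets, losses

# Step 1: unsupervised TSDAE on the target corpus
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
corpus_sentences = ["document summary 1 ...", "document summary 2 ..."]
tsdae_data = datasets.DenoisingAutoEncoderDataset(corpus_sentences)
tsdae_loader = DataLoader(tsdae_data, batch_size=8, shuffle=True)
tsdae_loss = losses.DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)
model.fit(train_objectives=[(tsdae_loader, tsdae_loss)], epochs=1,
          scheduler="constantlr", optimizer_params={"lr": 3e-5}, weight_decay=0)
model.save("output/tsdae-adapted")

# Step 2: supervised fine-tuning with MultipleNegativesRankingLoss on labeled pairs
model = SentenceTransformer("output/tsdae-adapted")
train_examples = [
    InputExample(texts=["anchor summary", "related summary"]),
    InputExample(texts=["another anchor", "its positive"]),
]
mnrl_loader = DataLoader(train_examples, batch_size=32, shuffle=True)
mnrl_loss = losses.MultipleNegativesRankingLoss(model)
model.fit(train_objectives=[(mnrl_loader, mnrl_loss)], epochs=1)
model.save("output/tsdae-then-mnrl")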
