TSDAE layer initialization of encoder and decoder
Hi @kwang2049 and @nreimers,
I want to pretrain a sentence transformer using TSDAE.
From my understanding I can either use a language model checkpoint (like bert-base-uncased, as is done in the example script), an already trained sentence-transformer, or train the model completely from randomly initialized weights.
We have previously used all-MiniLM-L6-v2 as a checkpoint and finetuned it with MultipleNegativesRankingLoss, with the main downstream task being document clustering (we use summaries of the documents as input to the sentence transformer). Now we want to see whether we can improve performance by first using TSDAE.
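For context, here is roughly the setup I am following, adapted from the TSDAE example script in the sentence-transformers docs (the corpus list and hyperparameters are just placeholders, not our real values):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, datasets, losses

model_name = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_name)

# Placeholder: raw summaries/sentences from our target corpus
train_sentences = ["First document summary ...", "Second document summary ..."]

# The dataset adds the denoising noise (token deletion) on the fly
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True)

# Decoder tied to the encoder, as in the example script
train_loss = losses.DenoisingAutoEncoderLoss(
    model, decoder_name_or_path=model_name, tie_encoder_decoder=True
)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    weight_decay=0,
    scheduler="constantlr",
    optimizer_params={"lr": 3e-5},
    show_progress_bar=True,
)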
Now the problem is that TSDAE does not look like it is initializing properly when following the example training script:
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)
When tie_encoder_decoder=True, the decoder_name_or_path will be invalid.
Some weights of BertLMHeadModel were not initialized from the model checkpoint at /home/soeren/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/ and are newly initialized: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'encoder.layer.3.crossattention.self.query.weight', 'encoder.layer.0.crossattention.self.query.weight', 'encoder.layer.0.crossattention.self.query.bias', 'encoder.layer.0.crossattention.output.LayerNorm.bias', 'encoder.layer.4.crossattention.self.query.bias', 'encoder.layer.1.crossattention.self.query.weight', 'encoder.layer.3.crossattention.output.LayerNorm.weight', 'encoder.layer.3.crossattention.self.query.bias', 'encoder.layer.2.crossattention.output.dense.bias', 'encoder.layer.0.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.self.key.bias', 'encoder.layer.1.crossattention.self.key.weight', 'encoder.layer.2.crossattention.self.query.weight', 'encoder.layer.0.crossattention.self.value.bias', 'encoder.layer.4.crossattention.output.LayerNorm.weight', 'encoder.layer.2.crossattention.self.key.weight', 'encoder.layer.2.crossattention.self.key.bias', 'encoder.layer.3.crossattention.self.key.bias', 'encoder.layer.5.crossattention.output.dense.weight', 'encoder.layer.1.crossattention.output.LayerNorm.bias', 'encoder.layer.3.crossattention.output.dense.bias', 'encoder.layer.1.crossattention.self.value.weight', 'cls.predictions.bias', 'encoder.layer.4.crossattention.self.key.bias', 'encoder.layer.5.crossattention.self.key.weight', 'encoder.layer.4.crossattention.output.LayerNorm.bias', 'encoder.layer.2.crossattention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'encoder.layer.4.crossattention.self.value.bias', 'encoder.layer.3.crossattention.output.dense.weight', 'encoder.layer.5.crossattention.self.query.bias', 'encoder.layer.0.crossattention.self.value.weight', 'encoder.layer.0.crossattention.self.key.weight', 'encoder.layer.2.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.self.query.weight', 'encoder.layer.2.crossattention.self.value.bias', 'encoder.layer.4.crossattention.output.dense.bias', 'encoder.layer.5.crossattention.output.dense.bias', 'encoder.layer.2.crossattention.output.LayerNorm.bias', 'encoder.layer.1.crossattention.self.value.bias', 'encoder.layer.3.crossattention.self.value.bias', 'encoder.layer.4.crossattention.self.value.weight', 'encoder.layer.3.crossattention.self.value.weight', 'encoder.layer.4.crossattention.self.key.weight', 'encoder.layer.1.crossattention.output.dense.bias', 'encoder.layer.5.crossattention.output.LayerNorm.bias', 'encoder.layer.4.crossattention.output.dense.weight', 'encoder.layer.2.crossattention.self.query.bias', 'encoder.layer.5.crossattention.self.value.bias', 'encoder.layer.3.crossattention.output.LayerNorm.bias', 'encoder.layer.1.crossattention.self.key.bias', 'encoder.layer.0.crossattention.output.dense.bias', 'encoder.layer.4.crossattention.self.query.weight', 'encoder.layer.3.crossattention.self.key.weight', 'encoder.layer.0.crossattention.self.key.bias', 'encoder.layer.0.crossattention.output.dense.weight', 'encoder.layer.1.crossattention.self.query.bias', 'encoder.layer.1.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.output.LayerNorm.weight', 'encoder.layer.5.crossattention.self.value.weight', 'encoder.layer.2.crossattention.output.dense.weight', 'encoder.layer.1.crossattention.output.dense.weight', 
'cls.predictions.transform.LayerNorm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The following encoder weights were not tied to the decoder ['bert/pooler']
First of all, it looks like all of the encoder's weights are not properly initialized, and I am unsure how the decoder is initialized when using a sentence-transformer. Also, the last line says something about the pooler not being tied to the decoder, which is probably referring to the pooling layer over the contextualized word embeddings?
TL;DR: I am unsure whether I am using TSDAE properly, since I am initializing from a sentence-transformer that does not have a decoder like BERT's LMHead.
Hi @sorenmc,
I am not really sure about your training and inference settings. Here I just comment on the general concern.
Actually, all-MiniLM-L6-v2 was trained on a lot (1B examples) of labeled datasets and domains, so in many cases it would outperform a model trained only with TSDAE on the downstream task. As shown in the paper, it is much better to first train with TSDAE on the target downstream task/corpus and then train with a supervised loss on the labeled data (from general domains). In other words, TSDAE can be viewed as a pre-training method. However, this is limited by training efficiency: you need to train again whenever a new task/corpus comes, and re-training the model on the training data of all-MiniLM-L6-v2 is not always possible due to its huge size.
To pursue optimal performance, I would suggest trying post-training methods like AugSBERT (for sentence-level tasks) or QGen/GPL (for query-document retrieval/IR tasks). These can be built directly on top of the existing supervised training effort and adapt the pre-trained model to the target task/domain.
P.S.: (1) As shown in the paper, TSDAE cannot improve performance when applied in the post-training manner. (2) To answer your specific question about training epochs, I found that just a few epochs (1~3) are usually enough.
Hi @sorenmc,
Really sorry to reply to you this late. I was quite busy during the past week.
Thanks a lot for your attention! Let me explain the warning messages one by one:
When tie_encoder_decoder=True, the weights of the encoder and the decoder will be shared (actually copied from the encoder), and thus the decoder_name_or_path will be ignored. This weight-tying option reduces the total number of training parameters while keeping the same performance (or even improving it). So in summary, you do not need to worry about these warning messages (especially in your case with all-MiniLM-L6-v2) 😃.
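If it helps, a small sketch of what this means in practice (assuming the current DenoisingAutoEncoderLoss signature, where decoder_name_or_path defaults to None; model and model_name are from your script above):

# With tie_encoder_decoder=True (the default), the decoder weights are copied/tied
# from the encoder, so decoder_name_or_path has no effect and can simply be omitted:
train_loss = losses.DenoisingAutoEncoderLoss(model, tie_encoder_decoder=True)

# Only with tie_encoder_decoder=False is a separate decoder actually loaded from
# decoder_name_or_path (independent weights, more parameters to train):
train_loss_untied = losses.DenoisingAutoEncoderLoss(
    model, decoder_name_or_path=model_name, tie_encoder_decoder=False
)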
BTW, just to give some advice here if needed (though I think you have already understood the point :)): TSDAE works great for domain adaptation/pre-training, so in your case you can try to (1) first do TSDAE on the corpus of your downstream task (i.e. unsupervised) and (2) then do MultipleNegativesRankingLoss on your downstream task (i.e. supervised).
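A rough sketch of that second, supervised stage (assuming the TSDAE-adapted model from step (1) was saved to a local path such as "output/tsdae-minilm", and that you build InputExample pairs from your summaries; both are placeholders):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Placeholder path: wherever the TSDAE-adapted model from step (1) was saved
model = SentenceTransformer("output/tsdae-minilm")

# Placeholder pairs: (summary, related text) from your clustering data
train_examples = [
    InputExample(texts=["summary of doc A", "related text for doc A"]),
    InputExample(texts=["summary of doc B", "related text for doc B"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.MultipleNegativesRankingLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=100,
    show_progress_bar=True,
)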