SBERT STS training task using original siamese networks
Hi. I recently tried to use Sentence-BERT for classification. I noticed that the STS training task is not the same as in the original siamese neural networks paper (Koch, Gregory, Richard Zemel, and Ruslan Salakhutdinov. "Siamese neural networks for one-shot image recognition." ICML Deep Learning Workshop. Vol. 2. 2015).
So I tried to train a new model as below:
```
BERT -> pooling -> u \
                     > L1 distance |u - v| -> dense layer -> score in [0, 1]
BERT -> pooling -> v /
```
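In other words (my own summary of the single-layer variant above, where W and b are the dense layer's parameters and sigma is the output activation), the predicted score is roughly:

$$\hat{s} = \sigma\left(W\,\lvert u - v\rvert + b\right), \qquad \hat{s} \in [0, 1]$$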
I would like to ask for opinions on this training setup, since I did not get meaningful results.
I have attached the loss module code below.
```python
# CustomDistanceMSELoss.py
import logging
from typing import Dict, Iterable

import torch
from torch import nn, Tensor

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class CustomDistanceMSELoss(nn.Module):
    """
    Siamese regression loss in the style of Koch et al. (2015): the element-wise
    L1 distance |u - v| between the two sentence embeddings is passed through a
    small classifier head, and the output is trained with MSE against a
    similarity score in [0, 1].

    :param model: SentenceTransformer model
    :param sentence_embedding_dimension: Dimension of your sentence embeddings
    :param num_labels: Number of output values (1 for a single similarity score)
    :param activation_function: Optional output activation ('tanh' or 'sigmoid')
    :param linear_num: Number of linear layers in the head (1 or 2)
    """

    def __init__(self,
                 model: SentenceTransformer,
                 sentence_embedding_dimension: int,
                 num_labels: int,
                 activation_function: str = None,
                 linear_num: int = 1):
        super(CustomDistanceMSELoss, self).__init__()
        self.model = model
        self.num_labels = num_labels

        # Optional activation on the final output
        self.act = None
        if activation_function == 'tanh':
            self.act = nn.Tanh()
        elif activation_function == 'sigmoid':
            self.act = nn.Sigmoid()

        self.linear_num = linear_num
        if linear_num == 1:
            self.classifier = nn.Linear(sentence_embedding_dimension, num_labels)
        elif linear_num == 2:
            self.classifier1 = nn.Linear(sentence_embedding_dimension, sentence_embedding_dimension)
            self.classifier2 = nn.Linear(sentence_embedding_dimension, num_labels)

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        rep_a, rep_b = reps

        # L1 distance between the two sentence embeddings
        features = torch.abs(rep_a - rep_b)

        if self.linear_num == 1:
            output = self.classifier(features)
        elif self.linear_num == 2:
            output = self.classifier1(features)
            output = nn.Sigmoid()(output)
            output = self.classifier2(output)

        if self.act:
            output = self.act(output)

        loss_fct = nn.MSELoss()
        if labels is not None:
            # Flatten so prediction and label shapes match when num_labels == 1,
            # and return only the loss so it works with SentenceTransformer.fit()
            loss = loss_fct(output.view(-1), labels.view(-1).float())
            return loss
        else:
            return reps, output
```
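For completeness, this is roughly how I plug the loss into training (a minimal sketch; the model name, batch size, and the toy examples are just placeholders, and I assume STS-style gold scores already scaled to [0, 1]):

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, models

from CustomDistanceMSELoss import CustomDistanceMSELoss

# Build a plain BERT + mean-pooling bi-encoder
word_embedding_model = models.Transformer('bert-base-uncased')
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# STS-style pairs with gold similarity scores in [0, 1]
train_examples = [
    InputExample(texts=['A man is playing a guitar.', 'A person plays guitar.'], label=0.9),
    InputExample(texts=['A man is playing a guitar.', 'A chef is cooking pasta.'], label=0.1),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

train_loss = CustomDistanceMSELoss(
    model=model,
    sentence_embedding_dimension=pooling_model.get_sentence_embedding_dimension(),
    num_labels=1,
    activation_function='sigmoid',  # keep the predicted score in [0, 1]
    linear_num=1,
)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100)
```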
Thank you
Because you can then no longer compare embeddings using efficient cosine similarity (or other similar similarity functions).
Either you want to compare many embeddings with each other, in which case you want to use cosine similarity etc.
Or you want to compare individual pairs. In that case, cross-encoders are much better (they require less data and achieve higher performance).
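To illustrate the two options (a minimal sketch; the checkpoint names are just examples of publicly available models):

```python
from sentence_transformers import SentenceTransformer, CrossEncoder, util

sentences = ['A man is playing a guitar.', 'A person plays guitar.']

# Option 1: bi-encoder + cosine similarity - embeddings can be cached and
# compared many-to-many very efficiently
bi_encoder = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = bi_encoder.encode(sentences, convert_to_tensor=True)
print(util.cos_sim(embeddings[0], embeddings[1]))

# Option 2: cross-encoder - scores individual pairs directly, usually more
# accurate, but both sentences must be passed through the model together
cross_encoder = CrossEncoder('cross-encoder/stsb-roberta-base')
print(cross_encoder.predict([(sentences[0], sentences[1])]))
```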
Thank you for your answer.