Create a wrapper to obtain word embeddings
Is your feature request related to a problem? Please describe.
Obtaining word representations using the tokenizer and model classes requires some boilerplate.
Describe the solution you’d like
What do you think about writing a wrapper that, given a list of sentences, a model name, and an embedding function name (i.e. how to combine the token vectors, e.g. “concat”, “sum”, etc.), would return the sentence representations?
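For illustration, the call could look roughly like the sketch below; the class name SentenceEmbedder, its arguments, and the encode method are hypothetical placeholders for the proposed wrapper, not an existing API:

# Hypothetical interface for the proposed wrapper (all names are illustrative only)
embedder = SentenceEmbedder(
    model_name="bert-base-uncased",  # any pretrained model name
    combine_strategy="mean",         # how to combine token vectors: "mean", "sum", "concat", ...
)
sentence_vectors = embedder.encode(["first sentence", "second sentence"])
# sentence_vectors: tensor of shape (num_sentences, embedding_dim)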
Describe alternatives you’ve considered
Below is the boilerplate code I use in my projects to obtain sentence representations from a BERT model; I’m thinking of a wrapper that does all of that under the hood:
import torch
import pandas as pd
from typing import List

from sklearn.base import BaseEstimator, TransformerMixin
from transformers import BertModel, BertTokenizer


def mean_across_all_tokens(hidden_states):
    # Average the last-layer token vectors into one vector per sentence
    return torch.mean(hidden_states[-1], dim=1)


def sum_all_tokens(hidden_states):
    # Sum the last-layer token vectors into one vector per sentence
    return torch.sum(hidden_states[-1], dim=1)


def concat_all_tokens(hidden_states):
    # Concatenate all last-layer token vectors into a single flat vector per sentence
    batch_size, max_tokens, emb_dim = hidden_states[-1].shape
    return torch.reshape(hidden_states[-1], (batch_size, max_tokens * emb_dim))


def CLS_token_embedding(hidden_states):
    # Use the last-layer [CLS] token vector as the sentence representation
    return hidden_states[-1][:, 0, :]


class BertTransformer(BaseEstimator, TransformerMixin):
    def __init__(
        self,
        max_length: int = 60,
        tokenizer=BertTokenizer.from_pretrained("bert-base-uncased"),
        model=BertModel.from_pretrained("bert-base-uncased", output_hidden_states=True),
        embedding_func=mean_across_all_tokens,
        combine_sentence_tokens=True,
    ):
        self.tokenizer = tokenizer
        self.combine_sentence_tokens = combine_sentence_tokens
        self.embedding_func = embedding_func
        self.model = model
        self.model.eval()
        self.max_length = max_length

    def _tokenize(self, text_list: List[str]) -> torch.Tensor:
        # Tokenize the text with the provided tokenizer
        # (pad_to_max_length is deprecated in newer transformers releases in favour of padding="max_length")
        input_ids = self.tokenizer.batch_encode_plus(
            text_list,
            add_special_tokens=True,
            max_length=self.max_length,
            pad_to_max_length=True,
        )["input_ids"]
        return torch.LongTensor(input_ids)

    def _tokenize_and_predict(self, text_list: List[str]) -> torch.Tensor:
        input_ids_tensor = self._tokenize(text_list)
        out = self.model(input_ids=input_ids_tensor)
        # With output_hidden_states=True, the third output is the tuple of per-layer hidden states
        hidden_states = out[2]
        if self.combine_sentence_tokens:
            return self.embedding_func(hidden_states)
        return hidden_states[-1]

    def transform(self, text_list: List[str]):
        if isinstance(text_list, pd.Series):
            text_list = text_list.tolist()
        with torch.no_grad():
            return self._tokenize_and_predict(text_list)

    def fit(self, X, y=None):
        """No fitting necessary, so we just return ourselves."""
        return self
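For context, a short usage sketch of the transformer above in a scikit-learn pipeline; the classifier choice and the toy data are arbitrary illustrations:

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# Toy example: embed sentences with BERT, then classify the embeddings
pipeline = Pipeline([
    ("embedder", BertTransformer(embedding_func=mean_across_all_tokens)),
    ("classifier", LogisticRegression(max_iter=1000)),
])
pipeline.fit(["a clearly positive sentence", "a clearly negative sentence"], [1, 0])
predictions = pipeline.predict(["another sentence to classify"])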
Issue Analytics
- Created: 3 years ago
- Comments: 10 (9 by maintainers)
Top GitHub Comments
Sounds perfect!
I think we can leave out the fine-tune method, as that could potentially be done through one of the other tasks (e.g. Language Modeling). The RepresentationModel can then be used to get the model representation for any input (independent of the task-specific head).

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
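As a rough sketch of what calling the RepresentationModel mentioned in the comment above could look like; the constructor arguments and the encode_sentences method below are illustrative assumptions, not a documented API:

# Illustrative only: assumed interface for a RepresentationModel-style wrapper
model = RepresentationModel(model_type="bert", model_name="bert-base-uncased")
sentence_vectors = model.encode_sentences(
    ["first sentence", "second sentence"],
    combine_strategy="mean",  # analogous to the embedding functions above
)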