Create a wrapper to obtain word embeddings

Is your feature request related to a problem? Please describe.
Obtaining word representations using the tokenizer and model classes requires some boilerplate.

Describe the solution you’d like
What do you think about writing a wrapper that, given a list of text sentences, a model name, and an embedding function name (how to combine token vectors, e.g. “concat”, “sum”, etc.), would return the sentence representations? A rough sketch of that interface is shown below.
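
A minimal, runnable sketch of roughly that interface, written against the Hugging Face tokenizer/model classes so the behaviour is concrete. The function name, arguments, and pooling-strategy names here are illustrative only, not an existing API:

from typing import List

import torch
from transformers import BertModel, BertTokenizer


def get_sentence_embeddings(
    sentences: List[str],
    model_name: str = "bert-base-uncased",
    combine_strategy: str = "mean",  # "mean", "sum", "concat" or "cls"
) -> torch.Tensor:
    # Tokenize, run the model, and pool the last layer's token vectors
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    model.eval()

    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        last_hidden = model(**encoded).last_hidden_state  # (batch, tokens, dim)

    if combine_strategy == "mean":
        return last_hidden.mean(dim=1)
    if combine_strategy == "sum":
        return last_hidden.sum(dim=1)
    if combine_strategy == "concat":
        return last_hidden.reshape(last_hidden.shape[0], -1)
    if combine_strategy == "cls":
        return last_hidden[:, 0, :]
    raise ValueError(f"Unknown combine_strategy: {combine_strategy}")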

Describe alternatives you’ve considered
Below is the boilerplate code I use in my projects to obtain sentence representations from a BERT model. I’m thinking of a wrapper that does all of this under the hood:

from typing import List

import pandas as pd
import torch
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import BertModel, BertTokenizer


# Pooling strategies: each takes the tuple of hidden states returned by the model
# and reduces the last layer's token vectors to one vector per sentence.
def mean_across_all_tokens(hidden_states):
    return torch.mean(hidden_states[-1], dim=1)

def sum_all_tokens(hidden_states):
    return torch.sum(hidden_states[-1], dim=1)

def concat_all_tokens(hidden_states):
    batch_size, max_tokens, emb_dim = hidden_states[-1].shape
    return torch.reshape(hidden_states[-1], (batch_size, max_tokens * emb_dim))

def CLS_token_embedding(hidden_states):
    # The [CLS] token is the first token of every sequence
    return hidden_states[-1][:, 0, :]

class BertTransformer(BaseEstimator, TransformerMixin):
    def __init__(
            self,
            max_length: int = 60,
            tokenizer = BertTokenizer.from_pretrained("bert-base-uncased"),
            model = BertModel.from_pretrained("bert-base-uncased", output_hidden_states=True),
            embedding_func = mean_across_all_tokens,
            combine_sentence_tokens=True
    ):
        self.tokenizer = tokenizer
        self.combine_sentence_tokens = combine_sentence_tokens
        self.embedding_func = embedding_func
        self.model = model
        self.model.eval()
        self.max_length = max_length

    def _tokenize(self, text_list: List[str]) -> torch.Tensor:
        # Tokenize the text with the provided tokenizer
        input_ids = self.tokenizer.batch_encode_plus(text_list,
                                                    add_special_tokens=True,
                                                    max_length=self.max_length,
                                                    pad_to_max_length=True
                                                    )["input_ids"]

        return torch.LongTensor(input_ids)
         

    def _tokenize_and_predict(self, text_list: List[str]) -> torch.Tensor:
        input_ids_tensor = self._tokenize(text_list)
        out = self.model(input_ids=input_ids_tensor)
        # With output_hidden_states=True, the per-layer hidden states are the
        # third element of the model output
        hidden_states = out[2]
        if self.combine_sentence_tokens:
            return self.embedding_func(hidden_states)
        else:
            return hidden_states[-1]
    
    def transform(self, text_list: List[str]):
        if isinstance(text_list, pd.Series):
            text_list = text_list.tolist()

        with torch.no_grad():
            return self._tokenize_and_predict(text_list)

    def fit(self, X, y=None):
        """No fitting necessary so we just return ourselves"""
        return self
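
Because the transformer above follows the scikit-learn estimator interface, it can be dropped straight into a Pipeline. A small usage sketch, assuming the class and pooling functions above are already defined; the toy texts and labels are purely illustrative:

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

texts = ["the movie was great", "the plot made no sense"]
labels = [1, 0]

pipeline = Pipeline([
    # Each text becomes one fixed-size vector via mean pooling over tokens
    ("embed", BertTransformer(embedding_func=mean_across_all_tokens)),
    ("clf", LogisticRegression()),
])

pipeline.fit(texts, labels)
print(pipeline.predict(["a surprisingly enjoyable film"]))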

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 10 (9 by maintainers)

Top GitHub Comments

1 reaction
ThilinaRajapakse commented, Jul 14, 2020

Sounds perfect!

I think we can leave out the fine-tune method as that could potentially be done through one of the other tasks (e.g. Language Modeling). The RepresentationModel can then be used to get the model representation for any input (independent of the task-specific head).
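
For context, this is roughly the direction the library took: simpletransformers added a RepresentationModel for exactly this use case. A hedged sketch of what that usage looks like, assuming the current API (argument names may differ between versions, so check the simpletransformers docs):

from simpletransformers.language_representation import RepresentationModel

# Wraps the tokenize-and-pool boilerplate shown earlier in the issue
model = RepresentationModel(
    model_type="bert",
    model_name="bert-base-uncased",
    use_cuda=False,
)

sentences = ["Example sentence one.", "Another example sentence."]
# combine_strategy controls how token vectors are pooled into one vector per sentence
vectors = model.encode_sentences(sentences, combine_strategy="mean")
print(vectors.shape)  # (number of sentences, hidden size)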

0 reactions
stale[bot] commented, Sep 14, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
