
How exactly to use GPU with KeyBERT?

See original GitHub issue

I’m trying to extract keywords and keyphrases from around 20k abstracts of journal articles. The FAQ mentions that a GPU is recommended when using KeyBERT. However, I’m unclear how exactly to run the extract_keywords function on a GPU. I tried model = KeyBERT() followed by model.to(device), but it says KeyBERT has no attribute ‘to’. I’d appreciate some help in running KeyBERT on a GPU. Thanks!
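
As the answers below explain, the device is set on the embedding backend rather than on KeyBERT itself. A minimal sketch of that approach, assuming the default sentence-transformers backend (the model name here is just an example):

from keybert import KeyBERT
from sentence_transformers import SentenceTransformer

# Load the embedding model directly on the GPU, then hand it to KeyBERT
st_model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
kw_model = KeyBERT(model=st_model)

keywords = kw_model.extract_keywords("Text of a journal abstract ...")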

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 14 (7 by maintainers)

Top GitHub Comments

1 reaction
MaartenGr commented, Jul 19, 2022

@Amaimersion Let me start off by saying thank you for this extensive search into what exactly is happening here! You are one of the few who go into this much depth, and it makes my work a whole lot easier 😄

There are a few small things that I have noticed, but I believe most of it is indeed due to the KeyphraseCountVectorizer, which I will come back to in a bit.

pip install transformers keybert spacy[cuda117] keyphrase_vectorizers

After performing the above, it might be worthwhile to check again whether CUDA is enabled. From your results I am quite sure it is, but just to be certain.
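
A quick sanity check along those lines, assuming both PyTorch and spaCy are installed:

import torch
import spacy

print(torch.cuda.is_available())  # True if PyTorch can see the GPU
print(spacy.prefer_gpu())         # True if spaCy can run on the GPU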

TOKENIZER_1 = PegasusTokenizer.from_pretrained("tuner007/pegasus_paraphrase", cache_dir=CACHE_DIR)

Thank you for this example; it clearly indicates that the GPU is working as it should in PyTorch.

call_2() on GPU without vectorizer
call_2() on CPU without vectorizer

Based on these, I think you are correct in stating that it is likely the KeyphraseCountVectorizer. In my experiments, that model can be quite slow compared to, for example, a SentenceTransformer model. The processing it needs to do seems to require much more compute, so it is unsurprising that it slows things down quite a bit. Having said that, you should still see some improvement when using a CUDA-enabled GPU, which you clearly have.

I believe what is happening is a mixture of two things:

  • The lengths of the documents make it a bit misleading
  • KeyphraseCountVectorizer, by default, uses a model optimized for the CPU, namely en_core_web_sm

The lengths of the documents make it a bit misleading
This might sound a bit strange seeing as you got the same results regardless of the length of the texts. The misleading part here is that SentenceTransformers simply truncates the text if it passes a certain length, but this same truncation does not happen in the KeyphraseCountVectorizer. Thus, the GPU will only be used for a short time on the truncated text, since embedding a single text is relatively quick.
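
To make the truncation concrete, a minimal sketch (the model name is just an example): sentence-transformers exposes the cut-off as max_seq_length, and tokens beyond it are dropped before embedding.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

# Tokens past this limit are dropped before embedding, so a long and a
# short abstract can cost roughly the same amount of GPU time
print(model.max_seq_length)  # e.g. 256

This leads me to the following: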

KeyphraseCountVectorizer uses a CPU-optimized model
The default model in KeyphraseCountVectorizer is spaCy’s en_core_web_sm, which is optimized for the CPU and not the GPU. What likely happens is that after embedding the documents with the SentenceTransformer, which is typically quite fast, the KeyphraseCountVectorizer will take quite some time to generate the candidate keywords.

I think the solution here is either to stop using the KeyphraseCountVectorizer or, and I would highly advise testing this out, to use the en_core_web_trf model instead. That model is, like a SentenceTransformer, a transformer model and thereby benefits from a GPU. This does not mean it will automatically be faster than en_core_web_sm, since the two differ in size and speed.
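
A minimal sketch of that second option, assuming the vectorizer’s spacy_pipeline argument accepts a pipeline name and that en_core_web_trf has already been downloaded:

from keybert import KeyBERT
from keyphrase_vectorizers import KeyphraseCountVectorizer

# Swap the default CPU pipeline (en_core_web_sm) for the transformer one;
# install it first with: python -m spacy download en_core_web_trf
vectorizer = KeyphraseCountVectorizer(spacy_pipeline="en_core_web_trf")

kw_model = KeyBERT()
docs = ["An example journal abstract about keyword extraction ..."]
keywords = kw_model.extract_keywords(docs, vectorizer=vectorizer)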

1 reaction
MaartenGr commented, May 23, 2022

@thtang This depends on the model that you are using; some support it and others do not. By default, sentence-transformers is used, and it will only use a single GPU. However, you can create a custom backend that supports this:

from keybert import KeyBERT
from keybert.backend import BaseEmbedder
from sentence_transformers import SentenceTransformer

class CustomEmbedder(BaseEmbedder):
    def __init__(self, embedding_model):
        super().__init__()
        self.embedding_model = embedding_model

        # all available CUDA devices will be used
        self.pool = self.embedding_model.start_multi_process_pool()

    def embed(self, documents, verbose=False):

        # Run encode() on multiple GPUs
        embeddings = self.embedding_model.encode_multi_process(documents, 
                                                               self.pool)
        return embeddings

# Create custom backend and pass it to KeyBERT
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
custom_embedder = CustomEmbedder(embedding_model=model)
kw_model = KeyBERT(model=custom_embedder)
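
A short usage sketch to go with the backend above; the multi-process pool should be shut down when you are done (assuming sentence-transformers’ stop_multi_process_pool helper):

docs = ["First abstract ...", "Second abstract ..."]
keywords = kw_model.extract_keywords(docs)

# Release the worker processes once all documents are embedded
model.stop_multi_process_pool(custom_embedder.pool)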
