
HuggingFace `Trainer` class compatibility

See original GitHub issue

Hi all,

First off, thank you for developing this awesome library! I have a question regarding the compatibility with the Trainer class in HuggingFace.

Background

I am gathering a set of NLP benchmarks to test some models and have already spent significant time developing my codebase with the HuggingFace transformers library. Some of my newer benchmarks require approaches only available out-of-the-box in sentence-transformers, such as training a bi-encoder with CosineSimilarityLoss. For consistency, I would ideally like to stick to the HuggingFace framework for most of my codebase instead of using two separate frameworks.

Question

In light of this, I was wondering whether it is possible to train/fine-tune a SentenceTransformer by using a modified HuggingFace Trainer class (with CosineSimilarityLoss for example)? Something along the lines of:

from torch import nn
from sentence_transformers import SentenceTransformer, models

# define model components
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(),
                           out_features=256, activation_function=nn.Tanh())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])

# use a modified HuggingFace Trainer for training
trainer = CustomTrainer(model=model, ...)
trainer.train()

Apologies if this was already answered elsewhere; I searched for similar threads but could not find much.

Thank you for your time.

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

1 reaction
atreyasha commented, Mar 21, 2022

@paluchnuggets Unfortunately not as yet 😕

But I had a look at some of the methods that need to be overridden in HF’s Trainer class. They list 12 methods, but I think the most important one to override for a minimal working example (MWE) is compute_loss. So if we could port CosineSimilarityLoss into compute_loss, we would already have something to look at.
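For reference, sentence-transformers’ CosineSimilarityLoss is, at its core, a mean-squared error between the cosine similarity of the two sentence embeddings and the gold similarity score. A minimal pure-Python sketch of that computation, which is what a compute_loss override would need to reproduce per pair (helper names here are illustrative, not the library’s API):

```python
import math

def cosine_similarity(u, v):
    # cosine of the angle between two embedding vectors
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def cosine_similarity_loss(u, v, gold_score):
    # squared error between predicted similarity and the gold label;
    # the library averages this over the batch (MSE)
    return (cosine_similarity(u, v) - gold_score) ** 2

# identical vectors with gold score 1.0 -> zero loss
print(cosine_similarity_loss([1.0, 0.0], [1.0, 0.0], 1.0))  # 0.0
# orthogonal vectors with gold score 1.0 -> loss 1.0
print(cosine_similarity_loss([1.0, 0.0], [0.0, 1.0], 1.0))  # 1.0
```

In the real override, u and v would be the pooled sentence embeddings produced by the SentenceTransformer forward pass, and the batched loss would be computed with torch operations so gradients flow.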

I’ll probably try this in the next days and report back. Feel free to also add comments/ideas 😃

0 reactions
avinashronanki commented, Nov 3, 2022

@atreyasha Thank you, I will take a look
