
Electra loss computation

See original GitHub issue

Hello - I looked at simpletransformers/language_modeling/language_modeling_model.py and it seems the loss computation for Electra does not take into account the discriminator loss.

On line 524 we have loss = outputs[0], whereas on line 470 of /simpletransformers/custom_models/models.py we are returning g_loss, d_loss, g_scores, d_scores, d_labels.

This suggests that only the generator loss is being optimized.
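For concreteness, here is a minimal sketch of what I would expect the training step to do with the tuple returned from models.py (the variable names and the weighting factor are my own illustration, not actual simpletransformers code):

# outputs = model(...) from ElectraForLanguageModelingModel.forward
g_loss, d_loss, g_scores, d_scores, d_labels = outputs

disc_loss_weight = 50.0  # illustrative placeholder for the discriminator loss weight
loss = g_loss + disc_loss_weight * d_loss  # optimize both losses, not just outputs[0]
loss.backward()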

In the paper (https://arxiv.org/pdf/2003.10555.pdf) the authors combine the NLL loss from the generator and BCE loss from the discriminator (top of page 4).
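Restating their combined objective as I read it (λ is the discriminator loss weight, set to 50 in their experiments):

\min_{\theta_G, \theta_D} \sum_{x \in \mathcal{X}} \mathcal{L}_{\text{MLM}}(x, \theta_G) + \lambda \, \mathcal{L}_{\text{Disc}}(x, \theta_D)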

Am I missing something?

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 29 (14 by maintainers)

Top GitHub Comments

2 reactions
aced125 commented, May 24, 2020

Sure

1 reaction
Laksh1997 commented, May 24, 2020

Hi - so Hugging Face already ties the input and output embeddings of a language model.

All we need to do is tie the input embeddings of the generator and discriminator (the discriminator has no output embeddings).

The code is simple and straightforward:

from transformers import ElectraForMaskedLM, ElectraForPreTraining, PreTrainedModel


class ElectraForLanguageModelingModel(PreTrainedModel):
    def __init__(self, generator_config, discriminator_config):
        super().__init__(discriminator_config)
        self.generator_config = generator_config
        self.discriminator_config = discriminator_config
        self.generator_model = ElectraForMaskedLM(generator_config)
        self.discriminator_model = ElectraForPreTraining(discriminator_config)
        self.vocab_size = generator_config.vocab_size
        self.tie_generator_and_discriminator_embeddings()

    def tie_generator_and_discriminator_embeddings(self):
        # Both sub-models expose their embedding layers at <model>.electra.embeddings
        gen_embeddings = self.generator_model.electra.embeddings
        disc_embeddings = self.discriminator_model.electra.embeddings

        # Tie word, position and token_type embeddings by sharing the same
        # Parameter objects between generator and discriminator
        gen_embeddings.word_embeddings.weight = disc_embeddings.word_embeddings.weight
        gen_embeddings.position_embeddings.weight = (
            disc_embeddings.position_embeddings.weight
        )
        gen_embeddings.token_type_embeddings.weight = (
            disc_embeddings.token_type_embeddings.weight
        )
