
Why do we need the init_weights function in a BERT pretrained model?

See original GitHub issue

❓ Questions & Help

I have already tried asking this question on SO; you can find the link here.

Details

In the Hugging Face transformers code, many of the fine-tuning models have an init_weights function. For example (here), there is an init_weights call at the end of the constructor. Even when we use from_pretrained, the constructor is still called, and with it init_weights.

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Runs unconditionally, even when the model is later
        # loaded via from_pretrained()
        self.init_weights()

As far as I know, this ends up calling the following code:

def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, BertLayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
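
For context, the _init_weights hook above is dispatched to every submodule by the public init_weights method. A rough sketch (simplified from the actual PreTrainedModel source) of how it reaches each layer:

def init_weights(self):
    # nn.Module.apply calls the given function on this module and,
    # recursively, on every child module, so _init_weights visits
    # each Linear, Embedding and LayerNorm exactly once.
    self.apply(self._init_weights)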

My question is: if we are loading a pre-trained language model, why do we need to initialize the weights of every module?

I guess I must be misunderstanding something here.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

26 reactions
BramVanroy commented, Jun 1, 2020

Have a look at the code for .from_pretrained(). What actually happens is something like this:

  • find the correct base model class to initialise
  • initialise that class with pseudo-random initialisation (by using the _init_weights function that you mention)
  • find the file with the pretrained weights
  • overwrite the weights of the model that we just created with the pretrained weights, where applicable

This ensures that layers that were not pretrained (e.g. in some cases the final classification layer) do get initialised in _init_weights but don’t get overridden by the checkpoint.
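
In code, a minimal sketch of that flow (not the actual transformers source; the checkpoint path bert_weights.bin is a hypothetical local file):

import torch
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)

# Steps 1-2: construct the model; __init__ calls self.init_weights(),
# so every module starts out pseudo-randomly initialised.
model = BertForSequenceClassification(config)

# Steps 3-4: load the checkpoint and copy its tensors in, overriding
# the random values wherever a key matches.
state_dict = torch.load("bert_weights.bin")  # hypothetical local file
missing, unexpected = model.load_state_dict(state_dict, strict=False)

# Keys listed in `missing` (e.g. classifier.weight, classifier.bias)
# are absent from the checkpoint and keep their random initialisation.
print(missing)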

3 reactions
BramVanroy commented, Nov 26, 2021

@sunersheng No, the random initialization happens first and then the existing weights are loaded into it.
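
A quick way to convince yourself of this ordering (a sketch, not from the thread): load the same checkpoint into two models and compare parameters.

import torch
from transformers import BertForSequenceClassification

a = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
b = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Pretrained tensors are identical: the checkpoint overrode the random
# initialisation in both models.
assert torch.equal(a.bert.embeddings.word_embeddings.weight,
                   b.bert.embeddings.word_embeddings.weight)

# The classifier head is not in the checkpoint, so each model keeps its
# own random draw from _init_weights.
assert not torch.equal(a.classifier.weight, b.classifier.weight)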

Read more comments on GitHub

Top Results From Across the Web

  • Initializing the weights of the final layer of e.g. ...
  • BERT Explained: State of the art language model for NLP
  • How much does pre-trained information help? Partially re ...
  • Transfer Learning — Powering up with Pretrained Models
  • How to Code BERT Using PyTorch - Tutorial With Examples
