BERT for passage reranking
See original GitHub issue.

Hi, I am currently trying to implement BERT for passage re-ranking in PyTorch. Here are the paper and GitHub repo: https://arxiv.org/abs/1901.04085 https://github.com/nyu-dl/dl4marco-bert
I've downloaded their BERT-large model checkpoint and BERT config for the task. The `convert_tf_checkpoint_to_pytorch` function seems to successfully extract the weights from TensorFlow.

Then, while initialising the PyTorch model, I get the following output and error:
```
Initialize PyTorch weight ['bert', 'pooler', 'dense', 'kernel']
Skipping bert/pooler/dense/kernel/adam_m
Skipping bert/pooler/dense/kernel/adam_v
Skipping global_step
```
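Those "Skipping" lines are expected: the TF checkpoint also stores Adam optimizer slots (`adam_m`, `adam_v`) and the `global_step` counter, which are not model weights, so the loader drops them. A minimal sketch of that filtering (simplified; `keep_variable` is a hypothetical name, the real logic lives inside `load_tf_weights_in_bert` in `pytorch_pretrained_bert/modeling.py`):

```python
# Simplified sketch of the variable-name filtering the TF-checkpoint loader
# performs; optimizer slots and the step counter are skipped, not loaded.
def keep_variable(name):
    return not any(
        part in ("adam_v", "adam_m", "global_step")
        for part in name.split("/")
    )

names = [
    "bert/pooler/dense/kernel",
    "bert/pooler/dense/kernel/adam_m",
    "bert/pooler/dense/kernel/adam_v",
    "global_step",
]
kept = [n for n in names if keep_variable(n)]
print(kept)  # ['bert/pooler/dense/kernel']
```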
```
     35
     36 # Load weights from tf checkpoint
---> 37 load_tf_weights_in_bert(model, tf_checkpoint_path)
     38
     39 # Save pytorch-model

~/anaconda3/envs/new_fast_ai/lib/python3.7/site-packages/pytorch_pretrained_bert/modeling.py in load_tf_weights_in_bert(model, tf_checkpoint_path)
     88     pointer = getattr(pointer, 'weight')
     89 elif l[0] == 'output_bias' or l[0] == 'beta':
---> 90     pointer = getattr(pointer, 'bias')
     91 elif l[0] == 'output_weights':
     92     pointer = getattr(pointer, 'weight')

~/anaconda3/envs/new_fast_ai/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    533     return modules[name]
    534 raise AttributeError("'{}' object has no attribute '{}'".format(
--> 535     type(self).__name__, name))
    536
    537 def __setattr__(self, name, value):

AttributeError: 'BertForPreTraining' object has no attribute 'bias'
```
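What the traceback shows: the loader walks each TF variable name segment by segment and maps `output_bias` to a `bias` attribute on the current pointer; for this classifier checkpoint the pointer is still the top-level model, and `BertForPreTraining` has no `bias` attribute there. A minimal mock of that failure mode, using plain Python classes in place of the real `nn.Module` tree (`FakeBertForPreTraining` and `follow` are hypothetical stand-ins, not the library's API):

```python
# Mock of the pointer traversal in load_tf_weights_in_bert: 'output_bias'
# is mapped to a 'bias' attribute, but the top-level model has none.
class FakeBertForPreTraining:
    """Stand-in for the real model; note: no `bias` attribute."""
    pass

def follow(pointer, tf_name):
    # Simplified version of the name-to-attribute loop
    for part in tf_name.split("/"):
        if part in ("output_bias", "beta"):
            pointer = getattr(pointer, "bias")
        elif part in ("output_weights", "kernel"):
            pointer = getattr(pointer, "weight")
        else:
            pointer = getattr(pointer, part, pointer)
    return pointer

try:
    follow(FakeBertForPreTraining(), "output_bias")
except AttributeError as e:
    print(e)  # 'FakeBertForPreTraining' object has no attribute 'bias'
```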
I assume the issue is with the final layer. What is the best way for me to go about resolving this?
Thanks in advance!
Issue Analytics
- Created: 4 years ago
- Comments: 15 (7 by maintainers)
Update for the latest transformers: see `modeling_bert.py:78` and `convert_bert_original_tf_checkpoint_to_pytorch.py`.
The `convert_tf_checkpoint_to_pytorch` script is made to convert the Google pre-trained weights into a `BertForPreTraining` model; you have to modify it to convert another type of model. In your case, you want to load the passage re-ranking model into a `BertForSequenceClassification` model, which has the same structure (BERT + a classifier on top of the pooled output) as the NYU model. Here is a quick way to do that:

1. Use a `BertForSequenceClassification` model instead of the `BertForPreTraining` model in the conversion script.
2. Add `pointer = getattr(pointer, 'cls')` in the TWO if-conditions related to `output_weights` and `output_bias` (between L89 and L90 and between L91 and L92 in modeling.py here: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L90 and https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L92).
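To illustrate step 2, here is a sketch of the patched traversal using plain Python objects in place of the real modules (`Head`, `FakeBertForSequenceClassification`, and `follow` are hypothetical mocks; the extra `getattr(pointer, 'cls')` hop mirrors the suggested patch, and the exact classifier attribute name may differ across library versions):

```python
# Mock showing the patched traversal: before resolving output_weights /
# output_bias, hop into the classifier head so the bias/weight lookup
# lands on a module that actually has those attributes.
class Head:
    def __init__(self):
        self.weight = [[0.0]]
        self.bias = [0.0]

class FakeBertForSequenceClassification:
    def __init__(self):
        # Stand-in for the classifier module reached via getattr(pointer, 'cls')
        self.cls = Head()

def follow(pointer, tf_name):
    for part in tf_name.split("/"):
        if part in ("output_weights", "output_bias"):
            # The patch: descend into the classifier module first
            pointer = getattr(pointer, "cls")
        if part in ("output_bias", "beta"):
            pointer = getattr(pointer, "bias")
        elif part in ("output_weights", "kernel"):
            pointer = getattr(pointer, "weight")
    return pointer

model = FakeBertForSequenceClassification()
print(follow(model, "output_bias"))     # [0.0]
print(follow(model, "output_weights"))  # [[0.0]]
```

With the extra hop in place, the same `output_bias` variable that crashed before now resolves to the classifier's bias instead of raising `AttributeError` on the top-level model.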