
Use fine-tuned BART-large to do conditional generation


Hi,

I am using a slightly old tag of your repo, where BART had run_bart_sum.py. I fine-tuned bart-large on a custom dataset and now want to do conditional generation:

from transformers import BartTokenizer, BartForConditionalGeneration
import torch

ARTICLE_TO_SUMMARIZE = "President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday -- less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning, saying that \"the federal government rose to the challenge and this is a great success story and I think that that's really what needs to be told.\""

# Attempt 1: load the fine-tuned checkpoint directly (fails, see traceback below)
# model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
# tokenizer = BartTokenizer.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')

# Attempt 2: load bart-large and set the state dict from the checkpoint (also fails, see below)
model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')
state = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
model.load_state_dict(state)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()


inputs = tokenizer.batch_encode_plus(
    [ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
summary_ids = model.generate(
    inputs['input_ids'], num_beams=1, max_length=512, early_stopping=True)


print([tokenizer.decode(g, skip_special_tokens=True,
                        clean_up_tokenization_spaces=False)
       for g in summary_ids])


I tried both loading the fine-tuned checkpoint directly and loading bart-large and then setting the state dict.

For the former, it gives me:

Traceback (most recent call last):
  File "generate.py", line 10, in <module>
    model = BartForConditionalGeneration.from_pretrained('./bart_sum/checkpointepoch=2.ckpt')
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 438, in from_pretrained
    **kwargs,
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 200, in from_pretrained
    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 252, in get_config_dict
    config_dict = cls._dict_from_json_file(resolved_config_file)
  File "/datastor/Softwarez/miniconda3/lib/python3.7/site-packages/transformers/configuration_utils.py", line 344, in _dict_from_json_file
    text = reader.read()
  File "/datastor/Softwarez/miniconda3/lib/python3.7/codecs.py", line 322, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte

For the latter:

Unexpected key(s) in state_dict: "epoch", "global_step", "checkpoint_callback_best", "optimizer_states", "lr_schedulers", "state_dict", "hparams", "hparams_type".
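(For context: those unexpected keys suggest the .ckpt file is a pytorch-lightning checkpoint, i.e. a dict of training state rather than a bare state_dict, which is why passing it straight to load_state_dict fails. A quick way to confirm, sketched under that assumption; the exact key set can vary by Lightning version:)

import torch

# Inspect the Lightning checkpoint: the model weights live under 'state_dict',
# alongside training metadata such as 'epoch' and 'optimizer_states'.
ckpt = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
print(list(ckpt.keys()))              # e.g. ['epoch', 'global_step', ..., 'state_dict', ...]
print(list(ckpt['state_dict'])[:5])   # the first few actual parameter names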


Top GitHub Comments

patil-suraj commented on May 4, 2020 (5 reactions):

If you used pytorch-lightning for training, then you can load the weights from the checkpoint as follows:

# The model weights live under the checkpoint's 'state_dict' key
ckpt = torch.load('./bart_sum/checkpointepoch=2.ckpt', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])

Once you load the weights this way, save the model using the .save_pretrained method so that next time you can load it using .from_pretrained.
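(A minimal sketch of that round trip, assuming the load_state_dict call above succeeded; ./bart_sum_hf is a hypothetical output directory:)

model.save_pretrained('./bart_sum_hf')        # writes config.json + pytorch_model.bin
tokenizer.save_pretrained('./bart_sum_hf')    # optional, keeps everything in one place

# Later, reload directly in Hugging Face format, no Lightning checkpoint needed:
model = BartForConditionalGeneration.from_pretrained('./bart_sum_hf')
tokenizer = BartTokenizer.from_pretrained('./bart_sum_hf')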

ieBoytsov commented on Jun 15, 2020 (2 reactions):

@sshleifer thanks for the link; meanwhile I managed to do what I wanted. Anyway, I will be glad to see further improvements for summarisation tasks.

For those who fine-tuned a BART model with finetune_bart.sh and want to load it in PyTorch, the following worked for me:

import torch
import pytorch_lightning as pl
from transformers import BartForConditionalGeneration

# Wrapper matching the LightningModule used for training, so that the
# checkpoint's 'model.'-prefixed parameter names line up
class BartModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

    def forward(self):
        pass

ckpt = torch.load('./bart_sum/checkpointepoch=1.ckpt', map_location='cpu')

bart_model = BartModel()
bart_model.load_state_dict(ckpt['state_dict'])

# Save the inner Hugging Face model so it can be reloaded with .from_pretrained
bart_model.model.save_pretrained("working_dir")
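(And a hypothetical usage sketch for the saved model: "working_dir" is the directory from the snippet above, the generation parameters are example values, and the tokenizer is loaded from facebook/bart-large since fine-tuning does not change BART's vocabulary:)

from transformers import BartTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("working_dir")
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

inputs = tokenizer.batch_encode_plus(
    ["Some long article to summarize ..."], max_length=1024, return_tensors='pt')
summary_ids = model.generate(
    inputs['input_ids'], num_beams=4, max_length=142, early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))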
