125m checkpoint outputting gibberish
Converting the sharded checkpoints of 125m to a singleton checkpoint with https://github.com/facebookresearch/metaseq/pull/60:
$ ls 125m
dict.txt
gpt2-merges.txt
gpt2-vocab.json
reshard-model_part-0.pt
reshard-model_part-1.pt
$ python -m metaseq.scripts.convert_to_singleton 125m
gives a new restored.pt file.
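As a quick sanity check on the merge, the parameter counts of the sharded parts can be compared against the merged file with something like the following (a minimal sketch; it assumes the shards keep their weights under a "model" key and that restored.pt is a plain dict of tensors, as used when repacking below):
import torch

part0 = torch.load("125m/reshard-model_part-0.pt", map_location="cpu")
part1 = torch.load("125m/reshard-model_part-1.pt", map_location="cpu")
restored = torch.load("125m/restored.pt", map_location="cpu")

# Assumption: each shard stores tensors under "model", while restored.pt is the flat state dict itself.
sharded = sum(t.numel() for part in (part0, part1) for t in part["model"].values())
merged = sum(t.numel() for t in restored.values())
print(sharded, merged)  # both should be in the same ballpark (~125M) if the merge worked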
I then transformed the checkpoint into the same format as 350m to test some generation on it:
import torch
orig_state = torch.load("./reshard-model_part-0.pt")
model = torch.load("./restored.pt")
orig_state["model"] = model # this format allows one to use the standard `checkpoint_utils.load_model_ensemble_and_task` function
orig_state["cfg"]["model"]._name = "transformer_lm" # we change the architecture name to "transformer_lm" to be able to run it in a non-CUDA environment
torch.save(orig_state, "./reshard.pt")
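Before running generation, the repacked file can be inspected to confirm it has the expected structure (a small sketch; it only checks the top-level keys and the overridden architecture name):
import torch

ckpt = torch.load("./reshard.pt", map_location="cpu")
print(sorted(ckpt.keys()))           # should contain at least "cfg" and "model"
print(ckpt["cfg"]["model"]._name)    # should now read "transformer_lm"
print(len(ckpt["model"]))            # number of parameter tensors in the merged state dict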
I tried running an inference example on the model to see whether generation works as expected. Here is the code:
import os
from transformers import GPT2Tokenizer
from metaseq import checkpoint_utils
import torch
path = "/home/patrick/add_opt"
"""
$ ls path
vocab.json
merges.txt
reshard.pt
"""
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)
paths = [os.path.join(path, "reshard.pt")]
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    paths,
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    }
)
model = checkpoint[0][0].eval()
# forward passes
def single_batch_forward_logits(prompts):
    input_ids = tokenizer(prompts, return_tensors="pt").input_ids
    input_ids = torch.cat([torch.tensor([[2]]), input_ids], dim=-1)
    logits = model(input_ids)[0]
    return logits
prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]
print("Next word generation")
for prompt in prompts:
    print("-------------")
    print(f"Prompt: {prompt}...\n")
    logits = single_batch_forward_logits(prompt)
    pred_next_token = torch.argmax(logits[0, -1], -1)
    next_token = tokenizer.convert_ids_to_tokens([pred_next_token])
    next_token = next_token[0].replace("Ġ", "")
    print(f"Next word: {next_token}")
    print("-------------")
This sadly gives gibberish:
Next word generation
-------------
Prompt: Today is a beautiful day and I want to...
Next word: Robbins
-------------
-------------
Prompt: In the city of...
Next word: of
-------------
-------------
Prompt: Paris is the capital of France and...
Next word: Robbins
-------------
-------------
Prompt: Computers and mobile phones have taken...
Next word: Robbins
-------------
Note that this script works perfectly fine with the 350m checkpoint.
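Since the 350m checkpoint works, one way to narrow this down might be to diff the two repacked state dicts and look for missing or oddly shaped keys (a sketch; the 350m path is a placeholder and the key layout is assumed to match):
import torch

good = torch.load("/path/to/350m/reshard.pt", map_location="cpu")["model"]
bad = torch.load("./reshard.pt", map_location="cpu")["model"]

# Keys present in only one of the two often point at a merging problem.
print(sorted(set(good) ^ set(bad))[:20])

# Shapes differ because the hidden sizes differ, but the pattern per key should match.
for k in sorted(set(good) & set(bad))[:10]:
    print(k, tuple(good[k].shape), tuple(bad[k].shape))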
@stephenroller - any ideas?
@patrickvonplaten I have a conversion script that our team used to convert the fairseq models to XGLM. I can use it to test this out, if that works? I am stuck on the script though, since it requires 2 GPU instances(?) to fully function.
Will try this in metaseq today.