RagTokenForGeneration.from_pretrained fails while running demo script
Environment info
- transformers version: 3.3.1
- Platform: Linux-4.4.0-1113-aws-x86_64-with-debian-stretch-sid
- Python version: 3.7.9
- PyTorch version (GPU?): 1.3.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help
@VictorSanh @patrickvonplaten @sshleifer (transformers/modeling_utils.py)
Information
Model I am using (Bert, XLNet …): RAG
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
The task I am working on is:
- an official GLUE/SQuAD task: (give the name)
- my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
- create a new conda env with Python 3.7
- install the RAG requirements
- run the example code from https://huggingface.co/transformers/master/model_doc/rag.html
Python 3.7.9 (default, Aug 31 2020, 12:42:55)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
>>> import torch
>>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
>>> retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
Using custom data configuration dummy.psgs_w100.nq.no_index
Reusing dataset wiki_dpr (/homes/thielk/.cache/huggingface/datasets/wiki_dpr/dummy.psgs_w100.nq.no_index/0.0.0/14b973bf2a456087ff69c0fd34526684eed22e48e0dfce4338f9a22b965ce7c2)
Using custom data configuration dummy.psgs_w100.nq.exact
Reusing dataset wiki_dpr (/homes/thielk/.cache/huggingface/datasets/wiki_dpr/dummy.psgs_w100.nq.exact/0.0.0/14b973bf2a456087ff69c0fd34526684eed22e48e0dfce4338f9a22b965ce7c2)
>>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
Stack trace:
Traceback (most recent call last):
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 187, in nti
n = int(s.strip() or "0", 8)
ValueError: invalid literal for int() with base 8: 'del.embe'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 2289, in next
tarinfo = self.tarinfo.fromtarfile(self)
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 1095, in fromtarfile
obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors)
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 1037, in frombuf
chksum = nti(buf[148:156])
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 189, in nti
raise InvalidHeaderError("invalid header")
tarfile.InvalidHeaderError: invalid header
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/site-packages/torch/serialization.py", line 595, in _load
return legacy_load(f)
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/site-packages/torch/serialization.py", line 506, in legacy_load
with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 1593, in open
return func(name, filemode, fileobj, **kwargs)
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 1623, in taropen
return cls(name, mode, fileobj, **kwargs)
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 1486, in __init__
self.firstmember = self.next()
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/tarfile.py", line 2301, in next
raise ReadError(str(e))
tarfile.ReadError: invalid header
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/site-packages/transformers/modeling_utils.py", line 927, in from_pretrained
state_dict = torch.load(resolved_archive_file, map_location="cpu")
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/site-packages/torch/serialization.py", line 426, in load
return _load(f, map_location, pickle_module, **pickle_load_args)
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/site-packages/torch/serialization.py", line 599, in _load
raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
RuntimeError: /homes/thielk/.cache/torch/transformers/06fe449ffe41cbe16aeb1f5976989313464a3c44a605e9a8b91bf6440dfa6026.696574d8c17eafbac08f43f01e951252057f8feb133b64a33b76d4c47d65367a is a zip archive (did you mean to use torch.jit.load()?)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/homes/thielk/miniconda3/envs/transformers-pytorch/lib/python3.7/site-packages/transformers/modeling_utils.py", line 930, in from_pretrained
"Unable to load weights from pytorch checkpoint file. "
OSError: Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
Expected behavior
Be able to run the example code from the RAG documentation to completion. May be related to #7583.
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
import torch
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
# initialize with RagRetriever to do everything in one forward call
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
input_ids = input_dict["input_ids"]
outputs = model(input_ids=input_ids, labels=input_dict["labels"])
# or use retriever separately
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
# 1. Encode
question_hidden_states = model.question_encoder(input_ids)[0]
# 2. Retrieve
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
# 3. Forward to generator
outputs = model(context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores, decoder_input_ids=input_dict["labels"])
# or directly generate
generated = model.generate(input_ids=input_dict["input_ids"])
generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
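As an aside, here is a minimal shape sketch of the doc_scores computation in steps 1–3 above. The dimensions below are, to the best of my knowledge, the defaults for facebook/rag-token-nq (n_docs=5, hidden size 768); the tensor values are dummies:

import torch

# question_hidden_states: (batch, dim); retrieved_doc_embeds: (batch, n_docs, dim)
batch, n_docs, dim = 1, 5, 768  # n_docs=5 is the RagConfig default
question_hidden_states = torch.randn(batch, dim)
retrieved_doc_embeds = torch.randn(batch, n_docs, dim)

# (batch, 1, dim) x (batch, dim, n_docs) -> (batch, 1, n_docs) -> (batch, n_docs):
# one dot-product score per retrieved document
doc_scores = torch.bmm(
    question_hidden_states.unsqueeze(1),
    retrieved_doc_embeds.transpose(1, 2),
).squeeze(1)
assert doc_scores.shape == (batch, n_docs)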
Okay, after some internal discussion, the error is the following: PyTorch officially changed its torch.save() serialization format in PyTorch 1.6.0 (see https://github.com/pytorch/pytorch/releases under "Deprecations" for 1.6.0), which means that models saved with torch >= 1.6.0 are not loadable with torch <= 1.4.0, hence this error. So for RAG the minimum required torch version seems to be torch 1.5.0. (Thanks @sgugger @LysandreJik.)

I can confirm that this error occurs with PyTorch version 1.4.0!