Getting "Wrong shape for input_ids" while trying to replicate example
See original GitHub issue.

Error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-8-1e9a71fa956b> in <module>
----> 1 all_common_phase_vec = model_sent.encode(all_common_phase)
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/sentence_transformers/SentenceTransformer.py in encode(self, sentences, batch_size, show_progress_bar, output_value, convert_to_numpy, convert_to_tensor, is_pretokenized)
185
186 with torch.no_grad():
--> 187 out_features = self.forward(features)
188 embeddings = out_features[output_value]
189
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
115 def forward(self, input):
116 for module in self:
--> 117 input = module(input)
118 return input
119
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/sentence_transformers/models/RoBERTa.py in forward(self, features)
32 def forward(self, features):
33 """Returns token_embeddings, cls_token"""
---> 34 output_states = self.roberta(**features)
35 output_tokens = output_states[0]
36 cls_tokens = output_tokens[:, 0, :] # CLS token is first token
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/transformers/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, output_attentions, output_hidden_states, return_dict)
802 # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
803 # ourselves in which case we just need to make it broadcastable to all heads.
--> 804 extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
805
806 # If a 2D or 3D attention mask is provided for the cross-attention
/flucast/anaconda3/envs/e2/lib/python3.8/site-packages/transformers/modeling_utils.py in get_extended_attention_mask(self, attention_mask, input_shape, device)
258 extended_attention_mask = attention_mask[:, None, None, :]
259 else:
--> 260 raise ValueError(
261 "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
262 input_shape, attention_mask.shape
ValueError: Wrong shape for input_ids (shape torch.Size([40])) or attention_mask (shape torch.Size([40]))
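For context, the check that raises here (get_extended_attention_mask) accepts 2D attention masks of shape [batch_size, seq_length] (or 3D ones), while the traceback shows 1D tensors of shape [40], i.e. a single unbatched sequence. A minimal illustration of the shape difference, not the library's internal code:

import torch

seq_length = 40
flat_mask = torch.ones(seq_length, dtype=torch.long)   # shape [40]: the unbatched form that triggers the ValueError above
batched_mask = flat_mask.unsqueeze(0)                   # shape [1, 40]: the batched form the model expects
print(flat_mask.shape, batched_mask.shape)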
- The error is generated when using the pretrained model (model_name: roberta-large-nli-stsb-mean-tokens); a minimal repro sketch is shown below.
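A minimal sketch of the call that raises the error, assuming all_common_phase is a plain Python list of strings (variable names taken from the traceback; the sentences here are placeholders):

from sentence_transformers import SentenceTransformer

# Load the pretrained model named in the report
model_sent = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')

# encode() is where the ValueError is raised in the traceback above
all_common_phase = ["first example sentence", "second example sentence"]
all_common_phase_vec = model_sent.encode(all_common_phase)
print(len(all_common_phase_vec))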
System running on:
Python 3.8.5 (default, Aug 5 2020, 08:36:46)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Packages installed in the env:
Package Version
--------------------- -------------------
argon2-cffi 20.1.0
attrs 20.1.0
backcall 0.2.0
bleach 3.1.5
blis 0.4.1
catalogue 1.0.0
certifi 2020.6.20
cffi 1.14.2
chardet 3.0.4
click 7.1.2
cymem 2.0.3
decorator 4.4.2
defusedxml 0.6.0
entrypoints 0.3
filelock 3.0.12
future 0.18.2
idna 2.10
ipykernel 5.3.4
ipython 7.17.0
ipython-genutils 0.2.0
jedi 0.17.2
Jinja2 2.11.2
joblib 0.16.0
json5 0.9.5
jsonschema 3.2.0
jupyter-client 6.1.7
jupyter-core 4.6.3
jupyterlab 2.2.6
jupyterlab-server 1.2.0
MarkupSafe 1.1.1
mistune 0.8.4
mkl-fft 1.1.0
mkl-random 1.1.1
mkl-service 2.3.0
murmurhash 1.0.2
nbconvert 5.6.1
nbformat 5.0.7
nltk 3.5
notebook 6.1.3
numpy 1.19.1
olefile 0.46
packaging 20.4
pandas 1.1.1
pandocfilters 1.4.2
parso 0.7.1
pexpect 4.8.0
pickleshare 0.7.5
Pillow 7.2.0
pip 20.2.2
plac 1.1.3
preshed 3.0.2
prometheus-client 0.8.0
prompt-toolkit 3.0.6
ptyprocess 0.6.0
pycparser 2.20
Pygments 2.6.1
pyparsing 2.4.7
pyrsistent 0.16.0
python-dateutil 2.8.1
pytz 2020.1
pyzmq 19.0.2
regex 2020.7.14
requests 2.24.0
sacremoses 0.0.43
scikit-learn 0.23.2
scipy 1.5.2
Send2Trash 1.5.0
sentence-transformers 0.3.3
sentencepiece 0.1.91
setuptools 49.6.0.post20200814
six 1.15.0
spacy 2.3.2
srsly 1.0.2
terminado 0.8.3
testpath 0.4.4
thinc 7.4.1
threadpoolctl 2.1.0
tokenizers 0.8.1rc2
torch 1.6.0
torchvision 0.7.0
tornado 6.0.4
tqdm 4.48.2
traitlets 4.3.3
transformers 3.0.2
urllib3 1.25.10
wasabi 0.8.0
wcwidth 0.2.5
webencodings 0.5.1
wheel 0.35.1
xlrd 1.2.0
Yes, it works when you downgrade transformers from version 3.1.0 to 3.0.2.

I had the same problem and downgrading the version worked for me. Thank you a lot ❤️
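For reference, the workaround above amounts to pinning transformers back to 3.0.2 in the environment, e.g.:

pip install transformers==3.0.2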