
The size of tensor a (100) must match the size of tensor b (17) at non-singleton dimension 3

See original GitHub issue

Hello, thanks for your excellent library!

When I try to prune a pre-trained BERT for 17-class text classification, my code is:

# -*- coding: UTF-8 -*-
import multiprocessing
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
from transformers import BertForSequenceClassification
from textpruner import summary, TransformerPruner, TransformerPruningConfig, inference_time

import directory
from helper.dataset import TextDataset
from run import RunConfig
from evaluate import test_pro

# run_config is referenced below but was not defined in the original snippet;
# instantiating RunConfig with no arguments is an assumption.
run_config = RunConfig()

# Load the fine-tuned 17-class classifier
model = BertForSequenceClassification.from_pretrained(directory.PRETRAIN_DIR, num_labels=17)
model.load_state_dict(torch.load('model/fold_1_best.pth'))
tokenizer = BertTokenizer.from_pretrained(directory.PRETRAIN_DIR)

# Dataloader used by TextPruner to compute the importance scores
test_df = test_pro()
test_dataset = TextDataset(test_df, np.arange(test_df.shape[0]))
test_loader = DataLoader(
    test_dataset, batch_size=run_config.batch_size, shuffle=True,
    num_workers=multiprocessing.cpu_count()
)

print(summary(model))

# Prune each layer to an FFN size of 1536 and 6 attention heads, in a single iteration
transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=1536, target_num_of_heads=6,
    pruning_method='iterative', n_iters=1)
pruner = TransformerPruner(model, transformer_pruning_config=transformer_pruning_config)
pruner.prune(dataloader=test_loader, save_model=True)
tokenizer.save_pretrained(pruner.save_dir)

print(summary(model))

But it raises the following error:

Calculating IS with loss:   0%|                                                                                                                                | 0/125 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/home/dell/programme/BERT-pruning/prune.py", line 57, in <module>
    pruner.prune(dataloader=test_loader, save_model=True)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/textpruner/pruners/transformer_pruner.py", line 86, in prune
    save_dir = self.iterative_pruning(dataloader, adaptor, batch_postprocessor, keep_shape, save_model=save_model, rewrite_cache=rewrite_cache)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/textpruner/pruners/transformer_pruner.py", line 149, in iterative_pruning
    head_importance, ffn_importance = self.get_importance_score(dataloader, adaptor, batch_postprocessor)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/textpruner/pruners/transformer_pruner.py", line 397, in get_importance_score
    outputs = model(*batch)
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1556, in forward
    outputs = self.bert(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1018, in forward
    encoder_outputs = self.encoder(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 607, in forward
    layer_outputs = layer_module(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 493, in forward
    self_attention_outputs = self.attention(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 423, in forward
    self_outputs = self.self(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 348, in forward
    attention_scores = attention_scores + attention_mask
RuntimeError: The size of tensor a (100) must match the size of tensor b (17) at non-singleton dimension 3

I found few materials or tutorials about TextPruner; maybe that is because the library is still quite new.

Please have a look at this bug when you are free. Thanks in advance!

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

1 reaction
airaria commented on Jun 27, 2022

Would you please describe each element of the batch and print their shapes, like this?

for batch in test_loader:
    print([t.shape for t in batch])
    break

It is weird to see 17 (the number of labels) appear in the calculation of the attention scores.

Thanks for your reply. I printed the shapes of the batches in test_loader and train_loader; they both show:

torch.Size([16, 100])
torch.Size([16, 17])

But it works well in training and evaluation orz… I am also confused about why the 17 in the attention is fine during training and evaluation but fails during pruning TAT

It looks like the model has wrongly treated the second tensor ([16, 17]) as the attention mask, because:

  1. In BertForSequenceClassification, the first positional argument is input_ids and the second is attention_mask by default
  2. The batch is a tuple or a list, not a dict, so its elements are passed to the model positionally
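
In other words, the pruner unpacks the tuple batch positionally, so the [16, 17] label tensor lands in the attention_mask slot. BERT then expands a 2-D mask to shape [batch, 1, 1, seq_len] and adds it to the [batch, num_heads, 100, 100] attention scores, which is exactly the broadcast that fails at dimension 3. A minimal, self-contained sketch of the clash (the tensor contents and the 12 heads are dummy values for illustration):

import torch

# Attention scores for a batch of 16, 12 heads, sequence length 100
attention_scores = torch.zeros(16, 12, 100, 100)

# The [16, 17] label tensor, mistaken for a 2-D attention mask, gets expanded
# to [16, 1, 1, 17] before being added to the scores
mistaken_mask = torch.zeros(16, 17)[:, None, None, :]

attention_scores + mistaken_mask
# RuntimeError: The size of tensor a (100) must match the size of tensor b (17)
# at non-singleton dimension 3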

To solve this, you can either

  1. modify your TextDataset code so that it returns a dict like {'input_ids': TensorA, 'labels': TensorB} (a minimal sketch is given below),

or

  2. define a function that takes a batch and returns a new batch as a dict whose keys are the model's argument names:

def batch_postprocessor(batch):
    return {'input_ids': batch[0], 'labels': batch[1]}

and then call the pruner as:

pruner.prune(dataloader=test_loader, save_model=True, batch_postprocessor=batch_postprocessor)
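
For the first option, a minimal sketch of what a dict-returning dataset could look like (the class and field names are assumptions, since the internals of your TextDataset are not shown; PyTorch's default collate_fn will batch the dict fields by key):

from torch.utils.data import Dataset

class DictTextDataset(Dataset):
    """Hypothetical variant of TextDataset whose items are dicts keyed by the
    model's argument names, so the pruner can pass the fields by name."""
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids   # e.g. LongTensor of shape [num_examples, 100]
        self.labels = labels         # e.g. tensor of shape [num_examples, 17]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx], 'labels': self.labels[idx]}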

P.S.: Also make sure the first element of the model output is the loss (otherwise you have to define another adaptor function; see ?pruner.prune).
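
If the loss is not the first element of your model's output, the adaptor could look roughly like the sketch below. This is an assumption based on the hint above that the adaptor receives the model output and must return the loss; check ?pruner.prune for the exact contract.

def adaptor(model_outputs):
    # Hypothetical adaptor: BertForSequenceClassification returns an output
    # object whose .loss field holds the scalar loss when labels are provided.
    return model_outputs.loss

pruner.prune(dataloader=test_loader, save_model=True,
             batch_postprocessor=batch_postprocessor, adaptor=adaptor)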

0 reactions
stale[bot] commented on Jul 12, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
