
How to do a downstream task for classification


Hi, I trained SimCLR with an AlexNet backbone as follows:

import torch
import torch.nn as nn
import torchvision
from lightly import models

model = torchvision.models.alexnet(pretrained=True)
# keep only the conv features ([:-2] drops avgpool and classifier), pool to a 256-d vector
model = nn.Sequential(*list(model.children())[:-2], nn.AdaptiveAvgPool2d(1))
model_ssl = models.SimCLR(model, num_ftrs=256, out_dim=128)
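For reference, a save step along these lines would produce a checkpoint compatible with the loading code below (a sketch: it assumes lightly's SimCLR wrapper exposes the wrapped model as model_ssl.backbone, and the key name simply matches the one read back later):

# hypothetical save step after SimCLR training; the checkpoint key matches
# the one used when loading the backbone for the downstream task
torch.save({'alexnet_backbone_parameters': model_ssl.backbone.state_dict()},
           'alexnet_backbone_params.pth')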

Then I tried to use it for a downstream classification task (8 classes). I loaded the backbone as follows:

import torch
import torch.nn as nn
import torchvision
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ = torchvision.models.alexnet()  # same architecture, untrained weights
model = nn.Sequential(*list(model_.children())[:-2], nn.AdaptiveAvgPool2d(1))
# restore the SimCLR-pretrained backbone weights
ckpt = torch.load('alexnet_backbone_params.pth')
model.load_state_dict(ckpt['alexnet_backbone_parameters'])

Then I tried to mimic the architecture of the original AlexNet and add a custom classifier as follows:

# assigning .classifier registers it as the last child of the Sequential,
# so it runs after the pooling layer in forward()
model = nn.Sequential(*list(model.children())[:-1], nn.AdaptiveAvgPool2d((6, 6)))
model.classifier = nn.Sequential(nn.Linear(9216, 1024),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(1024, 8),
                                 nn.LogSoftmax(dim=1))
print(model)
summary(model, (3, 80, 80))

It gives me an error like this:

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Linear(in_features=9216, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=1024, out_features=8, bias=True)
    (4): LogSoftmax(dim=1)
  )
)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-01ece421e2c0> in <module>
     16 
     17 print(model)
---> 18 summary(model, (3, 80, 80))

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
     70     # make a forward pass
     71     # print(x.shape)
---> 72     model(*x)
     73 
     74     # remove these hooks

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
     92 
     93     def forward(self, input: Tensor) -> Tensor:
---> 94         return F.linear(input, self.weight, self.bias)
     95 
     96     def extra_repr(self) -> str:

~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
   1751     if has_torch_function_variadic(input, weight):
   1752         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753     return torch._C._nn.linear(input, weight, bias)
   1754 
   1755 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3072x6 and 9216x1024)

Please help.

Thank you.
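The shapes in the error message point at the cause: the first Linear layer receives the pooled feature map still shaped (N, 256, 6, 6) instead of (N, 9216), because nothing flattens it before the classifier (torchsummary's default batch of 2 gives mat1 = 3072x6, i.e. 2 * 256 * 6 rows by 6 columns). A minimal sketch of a fix is to add nn.Flatten() at the start of the classifier when building the head above:

model = nn.Sequential(*list(model.children())[:-1], nn.AdaptiveAvgPool2d((6, 6)))
model.classifier = nn.Sequential(nn.Flatten(),  # (N, 256, 6, 6) -> (N, 9216)
                                 nn.Linear(9216, 1024),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(1024, 8),
                                 nn.LogSoftmax(dim=1))
summary(model, (3, 80, 80))  # the classifier now sees 9216 input features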

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 8 (5 by maintainers)

Top GitHub Comments

1 reaction
ramdhan1989 commented, May 26, 2021

I changed the lr to 0.00001 and now it runs without any NaN error.
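For context, a training setup with that learning rate might look like the following (a sketch: the comment does not say which optimizer was used, so Adam is assumed here):

import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1e-5)  # lr = 0.00001, as above
criterion = nn.NLLLoss()  # pairs with the LogSoftmax output of the classifier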

0 reactions
IgorSusmelj commented, May 20, 2021

Another thing that could be the problem is the range of the input values you feed into the model. What kind of transforms do you use in your trainloader dataloader? Do you use normalization?
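For illustration, a normalized input pipeline might look like the following (a sketch: the mean/std values are the standard ImageNet statistics commonly paired with a pretrained AlexNet, and the 80x80 size matches the summary call above):

import torchvision.transforms as T

transform = T.Compose([T.Resize((80, 80)),
                       T.ToTensor(),  # scales pixel values to [0, 1]
                       T.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])])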

