How to do a downstream task for classification
Hi, I trained SimCLR with the following backbone:
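import torch.nn as nn
import torchvision
from lightly import models  # assumed: the lightly library's (older) SimCLR wrapper
model = torchvision.models.alexnet(pretrained=True)
model = nn.Sequential(*list(model.children())[:-2], nn.AdaptiveAvgPool2d(1))  # conv features + global avg pool
model_ssl = models.SimCLR(model, num_ftrs=256, out_dim=128)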
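As a rough check of where num_ftrs=256 comes from, the spatial size after each layer of AlexNet's features block follows the usual formula floor((size + 2*padding - kernel) / stride) + 1. The stdlib-only sketch below walks the kernel, stride and padding values listed in the model printout in this issue. The 80×80 input size is an assumption taken from the summary call; the resolution used for SimCLR training is not stated.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a Conv2d/MaxPool2d along one spatial axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding, out_channels) for AlexNet's `features` block,
# copied from the model printout in this issue; None means "pooling keeps channels".
FEATURES = [
    ("conv1", 11, 4, 2, 64),
    ("pool1", 3, 2, 0, None),
    ("conv2", 5, 1, 2, 192),
    ("pool2", 3, 2, 0, None),
    ("conv3", 3, 1, 1, 384),
    ("conv4", 3, 1, 1, 256),
    ("conv5", 3, 1, 1, 256),
    ("pool3", 3, 2, 0, None),
]


def trace(size: int, channels: int = 3):
    """Return (name, channels, height, width) after every layer for a square input."""
    shapes = []
    for name, k, s, p, out_ch in FEATURES:
        size = out_size(size, k, s, p)
        channels = out_ch if out_ch is not None else channels
        shapes.append((name, channels, size, size))
    return shapes


for name, c, h, w in trace(80):
    print(f"{name}: {c}x{h}x{w}")
# last line printed: pool3: 256x1x1
```

For 80×80 inputs the features block ends at 256×1×1, so AdaptiveAvgPool2d(1) yields a 256-dimensional vector, matching num_ftrs=256. For the standard 224×224 ImageNet size, trace(224) ends at 256×6×6, which is where the 256·6·6 = 9216 input size of AlexNet's classifier comes from.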
Then I tried to do a downstream classification task (8 classes). I loaded the backbone as follows:
import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_ = models.alexnet()  # random init; the trained weights come from the checkpoint
model = nn.Sequential(*list(model_.children())[:-2], nn.AdaptiveAvgPool2d(1))
ckpt = torch.load('alexnet_backbone_params.pth')
model.load_state_dict(ckpt['alexnet_backbone_parameters'])
Then I tried to mimic the architecture of the original AlexNet and add a custom classifier as follows:
model = nn.Sequential(*list(model.children())[:-1], nn.AdaptiveAvgPool2d((6, 6)))
model.classifier = nn.Sequential(nn.Linear(9216, 1024),
                                 nn.ReLU(),
                                 nn.Dropout(0.5),
                                 nn.Linear(1024, 8),
                                 nn.LogSoftmax(dim=1))
print(model)
summary(model, (3, 80, 80))
It prints the model and then fails with this error:
Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Linear(in_features=9216, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=1024, out_features=8, bias=True)
    (4): LogSoftmax(dim=1)
  )
)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-14-01ece421e2c0> in <module>
16
17 print(model)
---> 18 summary(model, (3, 80, 80))
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
73
74 # remove these hooks
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
117 def forward(self, input):
118 for module in self:
--> 119 input = module(input)
120 return input
121
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
117 def forward(self, input):
118 for module in self:
--> 119 input = module(input)
120 return input
121
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
92
93 def forward(self, input: Tensor) -> Tensor:
---> 94 return F.linear(input, self.weight, self.bias)
95
96 def extra_repr(self) -> str:
~\AppData\Local\Continuum\anaconda3\envs\h20\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
1751 if has_torch_function_variadic(input, weight):
1752 return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753 return torch._C._nn.linear(input, weight, bias)
1754
1755
RuntimeError: mat1 and mat2 shapes cannot be multiplied (3072x6 and 9216x1024)
Please help.
Thank you!
Issue Analytics
- Created 2 years ago
- Comments: 8 (5 by maintainers)
I changed the learning rate to 0.00001, and now it runs without any NaN errors.
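For reference, a sketch of one training step for a LogSoftmax head using the learning rate that resolved the NaNs here (1e-5), plus a check that stops as soon as the loss becomes non-finite. The Adam optimizer, the tiny stand-in model and the random batch are assumptions for illustration; only the learning rate comes from this thread.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the fine-tuning model: any module ending in LogSoftmax works.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 80 * 80, 8), nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()  # expects log-probabilities, matching LogSoftmax
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)  # lr from this thread


def train_step(images: torch.Tensor, labels: torch.Tensor) -> float:
    """Run one optimisation step and fail fast if the loss is NaN or inf."""
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite loss: {loss.item()}")
    loss.backward()
    optimizer.step()
    return loss.item()


# Random batch of 4 images with labels in [0, 8).
images = torch.randn(4, 3, 80, 80)
labels = torch.randint(0, 8, (4,))
loss = train_step(images, labels)
print(math.isfinite(loss))  # True
```

Gradient clipping with torch.nn.utils.clip_grad_norm_ between loss.backward() and optimizer.step() is a common complementary guard against exploding updates; it was not tried in this thread.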
Another possible cause is the range of the input values you feed into the model. What transforms do you use in the dataloader trainloader? Do you use normalization?