Importing DINO (ViT-B/8) with pretrained weights into another project
I'm trying to evaluate DINO against human performance and therefore want to import the model with all the provided pretrained weights. I can access the backbone via torch.hub and create a linear classifier as in eval_linear.py. However, the dimensions do not match, and I think this is due to the reshaping steps in validate_network():
```python
with torch.no_grad():
    if "vit" in args.arch:
        intermediate_output = model.get_intermediate_layers(inp, n)
        output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        if avgpool:
            output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)
    else:
        output = model(inp)
output = linear_classifier(output)
loss = nn.CrossEntropyLoss()(output, target)
```
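For context, the input dimension the linear classifier expects follows from this concatenation. A sketch of the arithmetic, assuming the ViT-B settings from the DINO repo (embed_dim = 768; n = 1 last block and avgpool = True are the settings recommended for ViT-B evaluation):

```python
# Sketch of the classifier input dimension implied by the code above.
# Assumed values: embed_dim = 768 for ViT-B; n_last_blocks = 1 and
# avgpool = True as recommended for ViT-B in the DINO README.
embed_dim = 768
n_last_blocks = 1
avgpool = True
feat_dim = embed_dim * (n_last_blocks + int(avgpool))
print(feat_dim)  # 1536 -- the `dim` passed to LinearClassifier below
```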
How could I integrate this step into my initialization? So far, I have tried the following:
```python
import torch
import torch.distributed as dist
from torch import nn


class LinearClassifier(nn.Module):
    """Linear layer on top of frozen features (as in eval_linear.py)."""

    def __init__(self, dim, num_labels=1000):
        super(LinearClassifier, self).__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        # flatten
        print('First shape: ', x.shape)
        x = x.view(x.size(0), -1)
        print('After shape: ', x.shape)
        # linear layer
        return self.linear(x)


# the linear weights were saved from a DistributedDataParallel model, so a
# process group is needed to wrap the classifier the same way before loading
dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)

# load backbone
model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')

# set up the linear layer: 1536 = 768 * 2 ([CLS] token + averaged patch tokens)
linear_classifier = LinearClassifier(1536, 1000)
linear_classifier = linear_classifier.cuda()
linear_classifier = nn.parallel.DistributedDataParallel(linear_classifier)
linear_classifier.eval()

# load the pretrained linear weights
state_dict = torch.hub.load_state_dict_from_url(
    url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
)['state_dict']
linear_classifier.load_state_dict(state_dict, strict=True)

# sequentialise
model = torch.nn.Sequential(model, linear_classifier)
```
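For what it's worth, the mismatch can be seen directly: the hub backbone's plain forward returns only the final [CLS] embedding, so the Sequential feeds the 1536-dim head a 768-dim feature. A sketch (shapes assume ViT-B/8 at 224x224 input):

```python
# Sketch: why Sequential(backbone, linear_classifier) fails for ViT-B/8.
# The hub backbone's forward returns only the final [CLS] embedding.
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8').cuda()
x = torch.randn(1, 3, 224, 224).cuda()
with torch.no_grad():
    feats = backbone(x)
print(feats.shape)  # torch.Size([1, 768]) -- but the linear head expects 1536
```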
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Thanks a lot! I think I have come up with a solution based on this: I create a new model with the loaded backbone and the loaded linear weights, and do the concatenation step in the forward function. It works and seems to achieve the expected accuracy. Would you kindly let me know whether you think this is equivalent to what is happening in eval_linear.py?

Yes, I think this is equivalent. 😃
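A minimal sketch of what such a wrapper might look like, assuming n = 1 last block and avgpool = True as eval_linear.py uses for ViT-B (the class name and structure here are illustrative, not taken from the thread):

```python
import torch
from torch import nn


class DINOClassifier(nn.Module):
    """Illustrative wrapper: backbone tokens -> eval_linear-style concat -> linear head."""

    def __init__(self, backbone, linear_classifier, n_last_blocks=1, avgpool=True):
        super().__init__()
        self.backbone = backbone
        self.linear_classifier = linear_classifier
        self.n_last_blocks = n_last_blocks
        self.avgpool = avgpool

    def forward(self, x):
        with torch.no_grad():
            # each entry is (B, 1 + num_patches, embed_dim)
            intermediate = self.backbone.get_intermediate_layers(x, self.n_last_blocks)
            # concatenate the [CLS] tokens of the last n blocks
            out = torch.cat([t[:, 0] for t in intermediate], dim=-1)
            if self.avgpool:
                # append the mean of the last block's patch tokens, then flatten
                out = torch.cat(
                    (out.unsqueeze(-1),
                     torch.mean(intermediate[-1][:, 1:], dim=1).unsqueeze(-1)),
                    dim=-1,
                )
                out = out.reshape(out.shape[0], -1)  # (B, 2 * embed_dim)
        return self.linear_classifier(out)


# usage, given a loaded backbone and linear head as built above:
# full_model = DINOClassifier(backbone, linear_classifier).cuda().eval()
# logits = full_model(images)  # (B, 1000)
```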