Error in Loading State Dict
I am running into an error when trying to reload a fine-tuned version of the models. After further training, the models were saved using the code below:
import torch

# Load model
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t12_85M_UR50S")
# Training code
# Save model
torch.save(model.state_dict(), BEST_MODEL)
Upon trying to reload the model using the code below:
# Load model
model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t12_85M_UR50S")
model.load_state_dict(torch.load(BEST_MODEL))
I run into the error:
RuntimeError Traceback (most recent call last)
<ipython-input-4-b6ed0ebd023b> in <module>
----> 1 model.load_state_dict(torch.load("esm1_t12_85M_UR50S-Best.pt"))
~/anaconda3/envs/c10_stability/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1049
1050 if len(error_msgs) > 0:
-> 1051 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1052 self.__class__.__name__, "\n\t".join(error_msgs)))
1053 return _IncompatibleKeys(missing_keys, unexpected_keys)
This is a fairly standard PyTorch error that you get when trying to load the wrong state_dict into a model. Upon further inspection, it looks like all keys in the saved state_dict have a “model.” prefix on their names, while the keys in the model itself do not (I attached a file of the full error, which shows this mismatch in key names).
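A mismatch like this can be confirmed by comparing the two sets of key names directly. The following is a minimal sketch, not code from the original issue: `diff_keys` is a hypothetical helper, and the key names are made-up placeholders standing in for the real ESM parameter names.

```python
def diff_keys(saved_keys, model_keys):
    """Return (unexpected, missing): keys only in the checkpoint, keys only in the model."""
    saved, expected = set(saved_keys), set(model_keys)
    return sorted(saved - expected), sorted(expected - saved)


# Placeholder key names (assumed for illustration, not the actual ESM parameter names).
saved_keys = ["model.embed_tokens.weight", "model.layers.0.fc1.bias"]
model_keys = ["embed_tokens.weight", "layers.0.fc1.bias"]

unexpected, missing = diff_keys(saved_keys, model_keys)
print("unexpected:", unexpected)
print("missing:", missing)
```

With a real checkpoint, the same helper could be called as `diff_keys(torch.load(BEST_MODEL).keys(), model.state_dict().keys())`.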
What do you recommend for bypassing this problem? Is there a specific GitHub repo tag that I should be trying to work off of? It looks like there might have been some change to the torch.hub models between my original training script and trying to load now. Would it be reasonable to just remove the “model.” prefix in the saved state_dict?
Thank you!
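One way the prefix removal asked about above could look is sketched below. This is a hedged illustration rather than the maintainers' recommendation: `strip_prefix` is a hypothetical helper, the key names are placeholders, and integers stand in for tensors so the snippet runs without a checkpoint.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in cleaned:
            # Guard against two keys collapsing to the same name after stripping.
            raise KeyError(f"duplicate key after stripping prefix: {new_key}")
        cleaned[new_key] = value
    return cleaned


# Placeholder keys and values (assumed for illustration only).
saved = OrderedDict([("model.embed_tokens.weight", 1), ("model.layers.0.fc1.bias", 2)])
cleaned = strip_prefix(saved)
print(list(cleaned))  # ['embed_tokens.weight', 'layers.0.fc1.bias']
```

Applied to the real checkpoint, this would be `model.load_state_dict(strip_prefix(torch.load(BEST_MODEL)))`. Whether the stripped keys then match the model exactly still depends on the architecture being unchanged since training, so the default `strict=True` is worth keeping as a check.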
Top GitHub Comments
Thank you for the help! We now have everything working.
@brucejwittmann Thanks a lot! Will try to customize the output layers :)