How to use multiple PreTrainedModel models in a custom model?
I am using the `Trainer` to train a custom model, like this:

```python
import torch.nn as nn
import transformers

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # I want the code to be clean, so I load the pretrained models like this
        self.bert_layer_1 = transformers.AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
        self.bert_layer_2 = transformers.AutoModel.from_pretrained("bert-base-chinese")
        self.other_layers = ...  # not important

    def forward(self):
        pass  # not important
```
When running `trainer.save_model()`, it only saves the model's state dict, because the custom model is not a `PreTrainedModel` (as the terminal output below shows):

> Trainer.model is not a `PreTrainedModel`, only saving its state dict.

And when reloading the saved model in production, I need to initialize a new `MyModel` and load its state dict, which is not very convenient. I would like to load this model with `transformers.AutoModel.from_pretrained('MODEL_PATH')`, like other `PreTrainedModel`s.
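The manual reload path described above can be sketched with a toy stand-in for `MyModel` (a plain `nn.Module`; the mechanics are the same for the real model, only the architecture and checkpoint path differ):

```python
import torch
import torch.nn as nn

# Toy stand-in for MyModel; the save/reload mechanics are identical.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

model = ToyModel()
# For a plain nn.Module, Trainer.save_model() effectively just does this:
torch.save(model.state_dict(), "pytorch_model.bin")

# On the production side, the class must be re-instantiated by hand
# before the weights can be loaded back in:
reloaded = ToyModel()
reloaded.load_state_dict(torch.load("pytorch_model.bin"))
reloaded.eval()
```

This two-step dance (re-create the architecture, then load the state dict) is exactly the inconvenience described above, compared with a single `from_pretrained` call.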
I tried changing `class MyModel(nn.Module)` to `class MyModel(PreTrainedModel)`, but `PreTrainedModel` needs a `PretrainedConfig` at initialization. I don't have one in the current implementation, and I don't know how to manage the config when using multiple `PreTrainedModel` models. I want to keep `self.bert_layer_1` and `self.bert_layer_2` as simple `from_pretrained` calls, not `= BertModel(config)`.

Is there a way to do that?
Environment info
- `transformers` version: 4.9.2
- Platform: macOS / Ubuntu
- Python version: 3.8.6
- PyTorch version (GPU?): 1.8.1 (False) / (yes)
- Tensorflow version (GPU?): 2.4.1 (False) / (yes)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: parallel
Issue Analytics
- State:
- Created: 2 years ago
- Comments: 6 (3 by maintainers)
@sgugger Could you give an example of how to subclass `PreTrainedModel`? I would also like to integrate my model at https://huggingface.co/maxpe/twitter-roberta-base_semeval18_emodetection better with the `transformers` library. My attempt with `PyTorchModelHubMixin` didn't work well.

A model that is not inside the `transformers` library won't work with the AutoModel API. To properly use the save/from_pretrained methods, why not subclass `PreTrainedModel` instead of `nn.Module`?