Documentation for exporting custom architecture to ONNX
I have a custom BERT model for relation classification based on the R-BERT paper. The model performs well but is relatively slow on CPU, so I'd like to try exporting it to ONNX.
The model inherits from `BertPreTrainedModel` and is relatively simple:
```python
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel


class BertForRelationClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForRelationClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.cls_dropout = nn.Dropout(0.1)
        self.ent_dropout = nn.Dropout(0.1)
        # The [CLS] context vector and the two entity vectors are concatenated below.
        self.classifier = nn.Linear(config.hidden_size * 3, self.config.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                e1_mask=None, e2_mask=None, labels=None, position_ids=None,
                head_mask=None):
        outputs = self.bert(input_ids, position_ids=position_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask, head_mask=head_mask)
        pooled_output = outputs[1]
        sequence_output = outputs[0]

        def extract_entity(sequence_output, e_mask):
            # (batch, 1, seq) @ (batch, seq, hidden) -> (batch, hidden):
            # sums the hidden states over the entity's token positions.
            extended_e_mask = e_mask.unsqueeze(1)
            extended_e_mask = torch.bmm(
                extended_e_mask.float(), sequence_output).squeeze(1)
            return extended_e_mask.float()

        e1_h = self.ent_dropout(extract_entity(sequence_output, e1_mask))
        e2_h = self.ent_dropout(extract_entity(sequence_output, e2_mask))
        context = self.cls_dropout(pooled_output)
        pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs
        return outputs
```
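For reference, here's a minimal smoke test showing the shapes the forward pass expects; the entity-span positions below are purely illustrative (in R-BERT the entity masks mark the token positions of each entity mention):

```python
import torch
from transformers import BertConfig

# Illustrative smoke test with random token ids; the shapes are the point.
config = BertConfig(num_labels=5)
model = BertForRelationClassification(config)
model.eval()

batch, seq_len = 2, 16
input_ids = torch.randint(0, config.vocab_size, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

# Entity masks: 1.0 over each entity's token positions, 0.0 elsewhere.
e1_mask = torch.zeros(batch, seq_len)
e2_mask = torch.zeros(batch, seq_len)
e1_mask[:, 1:3] = 1.0  # pretend entity 1 spans tokens 1-2
e2_mask[:, 5:7] = 1.0  # pretend entity 2 spans tokens 5-6

with torch.no_grad():
    (logits,) = model(input_ids, attention_mask=attention_mask,
                      e1_mask=e1_mask, e2_mask=e2_mask)
print(logits.shape)  # torch.Size([2, 5])
```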
Following the guidelines for exporting custom models to ONNX, I've created a custom `OnnxConfig` for it and specified its inputs and outputs:
```python
from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig


class BertForRelationClassificationOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
                ("labels", {0: "batch", 1: "sequence"}),
                ("e1_mask", {0: "batch", 1: "sequence"}),
                ("e2_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("outputs", {0: "batch", 1: "sequence"})])
```
However, when I run `convert_graph_to_onnx.py`, the model is (of course) assumed to be `BertModel`, and the exported inputs and outputs are those of the vanilla `BertModel`. I'm unclear on the next steps.
I'm fairly sure this is what I should do next, as stated on the documentation page: "Once this is done, a single step remains: adding this configuration object to the initialisation of the model class, and to the general `transformers` initialisation".
I feel a bit dense, but I'm still not following how to make this work: the `BertForRelationClassificationOnnxConfig` class I created doesn't inherit from `BertConfig` (I could make it do so, but the documentation doesn't say to), so I don't see how to use it to initialize the model. The MBart example doesn't make sense to me either, as I'm not contributing to the transformers code base.
Can you please provide guidance or a specific example? Thank you!
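(For anyone landing here later: the registration step in the docs is only needed when contributing an architecture to transformers itself. For a standalone model, the config instance can apparently be passed straight to `transformers.onnx.export`, skipping `convert_graph_to_onnx.py` entirely. A minimal sketch; the argument order follows the 4.x `transformers.onnx.export` signature, and `path/to/finetuned-checkpoint` is a placeholder:)

```python
from pathlib import Path

from transformers import BertTokenizerFast
from transformers.onnx import export, validate_model_outputs

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForRelationClassification.from_pretrained("path/to/finetuned-checkpoint")
onnx_config = BertForRelationClassificationOnnxConfig(model.config)

onnx_path = Path("model.onnx")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path
)

# Optional sanity check: run ONNX Runtime and compare against the PyTorch model.
validate_model_outputs(
    onnx_config, tokenizer, model, onnx_path, onnx_outputs,
    onnx_config.atol_for_validation
)
```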
Having a duplicate for CamemBERT isn’t an issue 😃
@ChainYo, thank you very much! That makes a lot of sense and I was clearly misunderstanding.
I’ll do as you suggest and, if it works, close the issue.