
Documentation for exporting custom architecture to ONNX

I have a custom Bert model for relation classification based on the R-BERT paper. The model performs well but is relatively slow on CPU, so I’d like to try exporting to ONNX.

The model inherits from BertPreTrainedModel and is relatively simple:

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel


class BertForRelationClassification(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForRelationClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.cls_dropout = nn.Dropout(0.1)
        self.ent_dropout = nn.Dropout(0.1)
        # [CLS] representation concatenated with the two entity representations
        self.classifier = nn.Linear(config.hidden_size * 3, self.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                e1_mask=None, e2_mask=None, labels=None, position_ids=None,
                head_mask=None):
        outputs = self.bert(input_ids, position_ids=position_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask, head_mask=head_mask)
        sequence_output = outputs[0]  # per-token hidden states
        pooled_output = outputs[1]    # pooled [CLS] output

        def extract_entity(sequence_output, e_mask):
            # Pool the hidden states selected by the entity mask:
            # (batch, 1, seq) x (batch, seq, hidden) -> (batch, hidden)
            extended_e_mask = e_mask.unsqueeze(1)
            extended_e_mask = torch.bmm(
                extended_e_mask.float(), sequence_output).squeeze(1)
            return extended_e_mask.float()

        e1_h = self.ent_dropout(extract_entity(sequence_output, e1_mask))
        e2_h = self.ent_dropout(extract_entity(sequence_output, e2_mask))
        context = self.cls_dropout(pooled_output)
        pooled_output = torch.cat([context, e1_h, e2_h], dim=-1)

        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs  # only prepend loss when labels exist

        return outputs
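
For context, one route that sidesteps the transformers ONNX tooling entirely is to call torch.onnx.export on the model with dummy tensors matching the forward() signature above. The following is a minimal sketch, not a confirmed recipe from this thread; the checkpoint path, sequence length, and opset are placeholder assumptions.

# Sketch only: direct torch.onnx.export, bypassing the transformers ONNX
# tooling. The checkpoint path, sequence length, and opset are assumed values.
import torch

model = BertForRelationClassification.from_pretrained("path/to/checkpoint")
model.eval()  # disable dropout for a deterministic traced graph

batch, seq_len = 1, 128
# Dummy tensors in forward()'s positional order:
# (input_ids, token_type_ids, attention_mask, e1_mask, e2_mask)
dummy = (
    torch.ones(batch, seq_len, dtype=torch.long),
    torch.zeros(batch, seq_len, dtype=torch.long),
    torch.ones(batch, seq_len, dtype=torch.long),
    torch.zeros(batch, seq_len, dtype=torch.long),
    torch.zeros(batch, seq_len, dtype=torch.long),
)
names = ["input_ids", "token_type_ids", "attention_mask", "e1_mask", "e2_mask"]
dynamic_axes = {n: {0: "batch", 1: "sequence"} for n in names}
dynamic_axes["logits"] = {0: "batch"}  # logits are (batch, num_labels)

torch.onnx.export(
    model,
    dummy,
    "relation_classifier.onnx",
    input_names=names,
    output_names=["logits"],
    dynamic_axes=dynamic_axes,
    opset_version=12,
)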

Following the guidelines for exporting custom models to ONNX, I’ve created a custom OnnxConfig for it and specified its inputs and outputs:

from collections import OrderedDict
from typing import Mapping

from transformers.onnx import OnnxConfig


class BertForRelationClassificationOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
                ("labels", {0: "batch", 1: "sequence"}),
                ("e1_mask", {0: "batch", 1: "sequence"}),
                ("e2_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("outputs", {0: "batch", 1: "sequence"})])
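
One way such a config can be driven without convert_graph_to_onnx.py is the transformers.onnx.export function. The sketch below is an assumption-laden illustration, not an answer from this thread: it presumes a transformers version with that API (circa 4.11), drops the “labels” input for an inference-only graph, and overrides generate_dummy_inputs so the extra e1_mask/e2_mask tensors are supplied, since the base implementation only produces the tokenizer’s standard inputs. The subclass name and checkpoint path are hypothetical.

# Sketch only: drives the custom OnnxConfig through transformers.onnx.export
# (available around transformers 4.11). Subclass name, checkpoint path, and
# opset are assumptions for illustration.
from pathlib import Path

import torch
from transformers import BertTokenizer
from transformers.onnx import export


class RelationClassificationOnnxConfig(BertForRelationClassificationOnnxConfig):
    @property
    def inputs(self):
        # Inference-only export: drop "labels" so the graph just returns logits.
        mapping = super().inputs.copy()
        mapping.pop("labels", None)
        return mapping

    def generate_dummy_inputs(self, tokenizer, batch_size=-1, seq_length=-1,
                              is_pair=False, framework=None):
        # The base implementation only builds input_ids/attention_mask/
        # token_type_ids from the tokenizer; add the extra entity masks.
        dummy = super().generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length,
            is_pair=is_pair, framework=framework)
        shape = dummy["input_ids"].shape
        dummy["e1_mask"] = torch.zeros(shape, dtype=torch.long)
        dummy["e2_mask"] = torch.zeros(shape, dtype=torch.long)
        return dummy


tokenizer = BertTokenizer.from_pretrained("path/to/checkpoint")
model = BertForRelationClassification.from_pretrained("path/to/checkpoint")

onnx_config = RelationClassificationOnnxConfig(model.config)
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, 12, Path("relation_classifier.onnx"))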

However, when I run convert_graph_to_onnx.py, the model is (of course) assumed to be BertModel, and the inputs and outputs are those of the vanilla BertModel. I’m unclear on the next steps.

I’m fairly sure this is what I should do next, as stated on the documentation page: “Once this is done, a single step remains: adding this configuration object to the initialisation of the model class, and to the general transformers initialisation”.

While I feel a bit dense, I’m still not following how to make this work: the BertForRelationClassificationOnnxConfig class I created doesn’t inherit from BertConfig (I could make it do so, but the documentation doesn’t specify this), so I don’t see how I can use it to initialize the model. The MBart example doesn’t make sense to me either, as I’m not contributing to the transformers code base.

Can you please provide guidance or a specific example? Thank you!

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

1 reaction
LysandreJik commented, Oct 18, 2021

Having a duplicate for CamemBERT isn’t an issue 😃

1 reaction
ndobb commented, Oct 14, 2021

@ChainYo, thank you very much! That makes a lot of sense and I was clearly misunderstanding.

I’ll do as you suggest and, if it works, close the issue.
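
For completeness, here is a minimal sketch of sanity-checking an exported file with ONNX Runtime on CPU, assuming the hypothetical file name and the input/output names used in the torch.onnx.export sketch above:

# Sketch only: validates the exported graph with ONNX Runtime on CPU.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("relation_classifier.onnx",
                               providers=["CPUExecutionProvider"])

batch, seq_len = 1, 128
feed = {
    "input_ids": np.ones((batch, seq_len), dtype=np.int64),
    "token_type_ids": np.zeros((batch, seq_len), dtype=np.int64),
    "attention_mask": np.ones((batch, seq_len), dtype=np.int64),
    "e1_mask": np.zeros((batch, seq_len), dtype=np.int64),
    "e2_mask": np.zeros((batch, seq_len), dtype=np.int64),
}
(logits,) = session.run(["logits"], feed)
print(logits.shape)  # expected: (batch, num_labels)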

Top Results From Across the Web

  • Export to ONNX - Transformers - Hugging Face: In this guide, we'll show you how to export Transformers models to ONNX (Open Neural Network eXchange). Once exported, a model can be...
  • Tutorial 6: Exporting a model to ONNX: So far, our codebase supports onnx exporting from pytorch models trained ... --show: Determines whether to print the architecture of the exported...
  • Best Practices for Neural Network Exports to ONNX: Exporting your model to ONNX helps you to decouple the (trained) model from the rest of your project. Moreover, exporting also avoids environment...
  • ORT Mobile Model Export Helpers - ONNX Runtime: The ORT Mobile pre-built package only supports the most recent ONNX opsets in order to minimize binary size. Most ONNX models can be...
  • (optional) Exporting a Model from PyTorch to ONNX and ...: proto documentation.). Then, onnx.checker.check_model(onnx_model) will verify the model's structure and confirm that the model has a valid schema. The...
