
ONNX T5 with Beam Search

See original GitHub issue

Hey guys, I didn't know where this belonged, so I'm opening a generic issue. I was working on integrating the ONNX T5 code by @abelriboulot with the HuggingFace beam search decoding code, since I already had a decently performing T5 model for summarization and wanted to improve performance on CPU while maintaining inference accuracy. It works for the most part, but is slower, as the HF code uses cached past state values to speed up decoding. I got around this by creating two decoders with lm-head: one that doesn't take past values, for the initial decoding step, and another for subsequent steps where past values are considered. This is a bit complicated, as the past values have to be flattened out to pass through the ONNX graph, which I did, and it works for getting the output back. But when passing the input parameters, I get the following error:

    RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'Mul_48'
    Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/math/element_wise_ops.h:479
    void onnxruntime::BroadcastIterator::Init(int64_t, int64_t) axis == 1 || axis == largest was false.
    Attempting to broadcast an axis by a dimension other than 1. 2 by 3

I feel like I am close to a solution that could essentially be added to the repo, but this error is tripping me up 😦 Any help whatsoever would be appreciated. Thanks @mfuntowicz @abelriboulot @patrickvonplaten @patil-suraj @sshleifer
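For context, the "flattening" step mentioned above typically amounts to something like the following (a minimal sketch, not the author's actual code; it assumes the 12-layer, 4-tensors-per-layer layout used in the export snippet below):

    def flatten_past(past_key_value_states):
        # ONNX graphs take flat positional tensors, so the nested
        # 12-layer x 4-tensor structure becomes 48 separate inputs.
        flat = []
        for layer_states in past_key_value_states:
            flat.extend(layer_states)
        return flat

    def unflatten_past(flat_states, n_layers=12, per_layer=4):
        # Rebuild the per-layer tuples the torch-side code expects.
        return [
            tuple(flat_states[i * per_layer:(i + 1) * per_layer])
            for i in range(n_layers)
        ]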

ONNX Export code:

    import torch

    # decoder_with_lm_head, simplified_encoder, input_ids and output_prefix
    # come from earlier setup not shown in the issue.

    # Dummy past key/value states: 12 layers x (self-attn key, self-attn value,
    # cross-attn key, cross-attn value) = 48 tensors flattened through the graph.
    past_state_input_pre = torch.rand((1, 12, 1, 64))
    past_state_input_post = torch.rand((1, 12, 10, 64))
    past_key_value_states = [
        (past_state_input_pre, past_state_input_pre,
         past_state_input_post, past_state_input_post)
        for i in range(12)
    ]

    # Mark batch and sequence dimensions as dynamic on every input and output.
    past_val_outputs = {'past_states_op_' + str(i): {0: 'batch', 2: 'sequence'} for i in range(48)}
    past_val_inputs = {'past_states_ip' + str(i): {0: 'batch', 2: 'sequence'} for i in range(48)}
    dynamic_axes_dict = {
        'input_ids': {0: 'batch', 1: 'sequence'},
        'encoder_hidden_states': {0: 'batch', 1: 'sequence'},
    }
    dynamic_axes_dict.update(past_val_inputs)
    dynamic_axes_dict.update({'hidden_states': {0: 'batch', 1: 'sequence'}})
    dynamic_axes_dict.update(past_val_outputs)

    output_names_list = ['hidden_states'] + ['past_states_op_' + str(i) for i in range(48)]
    input_names_list = ['input_ids', 'encoder_hidden_states'] + ['past_states_ip' + str(i) for i in range(48)]

    # Exports to ONNX
    _ = torch.onnx.export(
        decoder_with_lm_head,
        (torch.tensor([[42]]), simplified_encoder(input_ids), past_key_value_states),
        f"{output_prefix}-decoder-with-lm-head.onnx",
        export_params=True,
        opset_version=12,
        input_names=input_names_list,
        output_names=output_names_list,
        dynamic_axes=dynamic_axes_dict,
    )
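A single decoding step against the exported graph would then look roughly like this (a sketch under stated assumptions: the file name is hypothetical, and encoder_hidden_states / flat_past stand in for arrays produced earlier; only the input and output names match the export above):

    import numpy as np
    import onnxruntime as ort

    # Hypothetical file name; in practice it is whatever
    # f"{output_prefix}-decoder-with-lm-head.onnx" resolved to.
    sess = ort.InferenceSession("t5-decoder-with-lm-head.onnx")

    # encoder_hidden_states: float32 array of shape (batch, seq, hidden);
    # flat_past: the 48 flattened past tensors as float32 arrays.
    inputs = {
        "input_ids": np.array([[42]], dtype=np.int64),
        "encoder_hidden_states": encoder_hidden_states,
    }
    for i, tensor in enumerate(flat_past):
        inputs["past_states_ip" + str(i)] = tensor

    # First output is hidden_states; the rest are the updated past states.
    hidden_states, *new_flat_past = sess.run(None, inputs)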

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 1
  • Comments: 9 (8 by maintainers)

Top GitHub Comments

1 reaction
patil-suraj commented, Oct 31, 2020

Hi @amanpreet692, I'm not sure what this error means, but I have a T5 ONNX version ready which is compatible with the generate method. To be able to use the cache, I exported the encoder and lm_head to ONNX and kept the decoder in torch. This is a bit hacky but still gives a 1.4-1.6x speed-up for beam search; I'll be sharing it soon.
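His implementation isn't shown in this thread, but the split he describes could look roughly like this (a sketch with hypothetical session and input names; only the division between ONNX and torch components comes from the comment above):

    import torch

    class HybridOnnxT5(torch.nn.Module):
        # Sketch of the hybrid layout: encoder and lm_head run in onnxruntime,
        # while the decoder stays in torch so HuggingFace's past-key-value
        # cache keeps working during generation.
        def __init__(self, encoder_sess, lm_head_sess, torch_decoder):
            super().__init__()
            self.encoder_sess = encoder_sess  # ort.InferenceSession for the encoder
            self.lm_head_sess = lm_head_sess  # ort.InferenceSession for the LM head
            self.decoder = torch_decoder      # original torch T5 decoder stack

        def encode(self, input_ids):
            # "input_ids" is an assumed ONNX input name for the encoder graph.
            (hidden,) = self.encoder_sess.run(
                None, {"input_ids": input_ids.numpy()}
            )
            return torch.from_numpy(hidden)

        def decode_step(self, decoder_input_ids, encoder_hidden, past=None):
            # Torch decoder keeps its native cache across steps.
            out = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_hidden,
                past_key_values=past,
                use_cache=True,
            )
            # "hidden_states" is an assumed ONNX input name for the LM head graph.
            (logits,) = self.lm_head_sess.run(
                None, {"hidden_states": out.last_hidden_state.numpy()}
            )
            return torch.from_numpy(logits), out.past_key_values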

0 reactions
stale[bot] commented, Jan 2, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.


Top Results From Across the Web

  • Result difference when running beam search on ONNX T5 ...
    I converted a T5 base model to ONNX and I implemented the IO bindings method for inference. When I ran the greedy search...
  • Speeding up T5 inference - Transformers
    In my experiments this gave ~1.4-1.6x speed-up with beam search. The first time you call OnnxT5 it'll load the model from the hub,...
  • Using beam search with the TensorRT compiled T5 model?
    I have been using the code in TensorRT (TensorRT/demo/HuggingFace/T5) that builds decoder and encoder engines from the HuggingFace T5 model.
  • Does converting a seq2seq NLP model to the ONNX format ...
    for onnx seq2seq model, you need to implement model.generate() method ... it implements both greedy and beam search for t5. for bart have...
  • Suraj Patil on Twitter: "Want to speed-up T5 generation ? head ...
    ... convert T5 to onnx and make it compatible with generate method to be able to use beam search,sampling etc and gives ~1.4-1.6x...
