ONNX T5 with Beam Search
Hey guys, I didn't know where this belonged, so I'm opening a generic issue. I was working on integrating the ONNX T5 code by @abelriboulot with the HuggingFace beam search decoding code, since I already had a decently performing T5 model for summarization and wanted to improve performance on CPU while maintaining inference accuracy. It works for the most part, but it is slower, because the HF code uses cached past state values to speed up decoding. I got around this by creating two decoders with LM head: one that doesn't take past values, for the initial decoding step, and another for subsequent steps, where past values are used. This is a bit complicated, as the past values have to be flattened out to pass through the ONNX graph; I did that, and it works for getting the output back. But when passing the input parameters, I get the following error:

```
RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'Mul_48'
Status Message: /Users/runner/work/1/s/onnxruntime/core/providers/cpu/math/element_wise_ops.h:479
void onnxruntime::BroadcastIterator::Init(int64_t, int64_t) axis == 1 || axis == largest was false.
Attempting to broadcast an axis by a dimension other than 1. 2 by 3
```
I feel like I'm close to a solution that could eventually be added to the repo, but this error is tripping me up 😦 Any help whatsoever will be appreciated. Thanks @mfuntowicz @abelriboulot @patrickvonplaten @patil-suraj @sshleifer
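For context, the second decoder (the one that consumes past values) is wrapped roughly like the sketch below. This is simplified, not my exact module; the argument names and the 12-layers-by-4-tensors layout follow T5's past-state structure.

```python
import torch

class DecoderWithLMHead(torch.nn.Module):
    """Sketch of the past-consuming decoder wrapper (names illustrative)."""

    def __init__(self, decoder, lm_head):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head

    def forward(self, input_ids, encoder_hidden_states, past_key_value_states):
        # past_key_value_states: 12 layers x (self-attn key, self-attn value,
        # cross-attn key, cross-attn value)
        hidden_states, new_past = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            past_key_value_states=past_key_value_states,
            use_cache=True,
        )[:2]
        logits = self.lm_head(hidden_states)
        # Flatten the 12 x 4 past tensors into 48 flat outputs so they map
        # onto the past_states_op_* names in the export below (the logits
        # map onto the first output name, 'hidden_states').
        flat_past = [t for layer in new_past for t in layer]
        return (logits, *flat_past)
```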
ONNX Export code:
```python
import torch

# Dummy past tensors used to trace the graph, shaped
# (batch, num_heads, seq_len, head_dim): (1, 12, 1, 64) for the decoder
# self-attention states and (1, 12, 10, 64) for the cross-attention states.
past_state_input_pre = torch.rand((1, 12, 1, 64))
past_state_input_post = torch.rand((1, 12, 10, 64))

# 12 layers x (self-attn key, self-attn value, cross-attn key, cross-attn value)
past_key_value_states = [
    (past_state_input_pre, past_state_input_pre,
     past_state_input_post, past_state_input_post)
    for _ in range(12)
]

# Dynamic axes for the 48 flattened past-state inputs and outputs.
past_val_outputs = {'past_states_op_' + str(i): {0: 'batch', 2: 'sequence'}
                    for i in range(48)}
past_val_inputs = {'past_states_ip' + str(i): {0: 'batch', 2: 'sequence'}
                   for i in range(48)}

dynamic_axes_dict = {
    'input_ids': {0: 'batch', 1: 'sequence'},
    'encoder_hidden_states': {0: 'batch', 1: 'sequence'},
}
dynamic_axes_dict.update(past_val_inputs)
dynamic_axes_dict.update({'hidden_states': {0: 'batch', 1: 'sequence'}})
dynamic_axes_dict.update(past_val_outputs)

output_names_list = ['hidden_states'] + ['past_states_op_' + str(i) for i in range(48)]
input_names_list = ['input_ids', 'encoder_hidden_states'] + ['past_states_ip' + str(i) for i in range(48)]

# Exports to ONNX; decoder_with_lm_head, simplified_encoder, input_ids and
# output_prefix are defined earlier in my script.
torch.onnx.export(
    decoder_with_lm_head,
    (torch.tensor([[42]]), simplified_encoder(input_ids), past_key_value_states),
    f"{output_prefix}-decoder-with-lm-head.onnx",
    export_params=True,
    opset_version=12,
    input_names=input_names_list,
    output_names=output_names_list,
    dynamic_axes=dynamic_axes_dict,
)
```
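At inference time I feed the flattened past states back in roughly like this. A sketch: `run_decoder_step` is just an illustrative helper, and all inputs are numpy arrays, as ONNX Runtime expects.

```python
import onnxruntime

decoder_sess = onnxruntime.InferenceSession(
    f"{output_prefix}-decoder-with-lm-head.onnx")

def run_decoder_step(input_ids, encoder_hidden_states, flat_past):
    """Run one decoding step; flat_past is the list of 48 numpy arrays
    produced by the previous step (or by the no-past decoder initially)."""
    feed = {
        'input_ids': input_ids,
        'encoder_hidden_states': encoder_hidden_states,
    }
    feed.update({'past_states_ip' + str(i): flat_past[i] for i in range(48)})
    outputs = decoder_sess.run(None, feed)
    # First output is the logits; the remaining 48 are the updated past
    # states, in the same flat order as the inputs.
    return outputs[0], outputs[1:]
```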
Hi @amanpreet692, I'm not sure what this error means, but I have a T5 ONNX version ready which is compatible with the `generate` method. To be able to use the cache, I exported the `encoder` and `lm_head` to ONNX and kept the `decoder` in torch. This is a bit hacky but still gives a 1.4-1.6x speed-up for beam search; I'll be sharing it soon.