Issue using GPT2 ONNX export with past key values
System Info
python: 3.10.6
platform: Ubuntu 22.10
optimum version: 1.5.1
onnxruntime: 1.13.1
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
Command line to export a GPT2 model:
python -m optimum.exporters.onnx --model gpt2 --task causal-lm-with-past output/
This gives the following output logs:
Framework not specified. Using pt to export to ONNX.
Using framework PyTorch: 1.13.0+cu117
Overriding 2 configuration item(s)
- use_cache -> True
- pad_token_id -> 0
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:796: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if batch_size <= 0:
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:185: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
attn_weights = attn_weights / torch.tensor(
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:185: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
attn_weights = attn_weights / torch.tensor(
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py:200: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
Validating ONNX model...
-[✓] ONNX model output names match reference model (present.1.value, present.0.key, present.6.key, present.6.value, present.5.value, present.8.key, present.0.value, present.2.key, present.5.key, present.10.key, present.9.value, present.10.value, logits, present.4.value, present.7.key, present.11.value, present.3.value, present.3.key, present.4.key, present.2.value, present.1.key, present.9.key, present.11.key, present.8.value, present.7.value)
- Validating ONNX Model output "logits":
-[✓] (2, 16, 50257) matches (2, 16, 50257)
-[x] values not close enough, max diff: 0.0013427734375 (atol: 1e-05)
- Validating ONNX Model output "present.0.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.0.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.1.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.1.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.2.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.2.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.3.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.3.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.4.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.4.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.5.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.5.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.6.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.6.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.7.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.7.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.8.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.8.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.9.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.9.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.10.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.10.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.11.key":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
- Validating ONNX Model output "present.11.value":
-[β] (2, 12, 32, 64) matches (2, 12, 32, 64)
-[β] all values close (atol: 1e-05)
An error occured, but the model was saved at: model_repository/gpt2/1/model.onnx
Even though there is an error in the close-values validation, that's OK for my purposes (the max diff is small). Now I would like to run the model with the following Python:
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import GPT2Tokenizer

# Load the already-exported model (no re-conversion) with the KV cache enabled.
model = ORTModelForCausalLM.from_pretrained("output/", from_transformers=False, use_cache=True)
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = tokenizer("My name is Julien and I like", return_tensors="pt")
outputs_model = model.generate(**tokens)
And I get the following error:
/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/generation_utils.py:1359: UserWarning: Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to 20 (`self.config.max_length`). Controlling `max_length` via the config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.
warnings.warn(
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/generation_utils.py", line 1490, in generate
return self.greedy_search(
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/transformers/generation_utils.py", line 2233, in greedy_search
outputs = self(
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/modeling_base.py", line 60, in __call__
return self.forward(*args, **kwargs)
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/optimum/onnxruntime/modeling_ort.py", line 1454, in forward
outputs = self.model.run(None, onnx_inputs)
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 196, in run
raise ValueError("Model requires {} inputs. Input Feed contains {}".format(num_required_inputs, num_inputs))
ValueError: Model requires 26 inputs. Input Feed contains 2
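If I read the export correctly, the 26 comes from input_ids and attention_mask plus one key and one value tensor per layer: 2 + 2 × 12 = 26. A quick way to check, assuming the model.onnx produced by the export command above:

import onnxruntime as ort

sess = ort.InferenceSession("output/model.onnx", providers=["CPUExecutionProvider"])
print(len(sess.get_inputs()))            # expected: 26
print([i.name for i in sess.get_inputs()])
# should list input_ids and attention_mask plus the past_key_values.*.key/value pairs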
Do I have to feed the past_key_values.X.key and past_key_values.X.value inputs myself?
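From what I understand, random values are not the way to go: for the first forward pass the usual trick seems to be feeding empty past tensors, with a past sequence length of 0. A minimal sketch, assuming the default gpt2 geometry (12 layers, 12 attention heads, head size 64):

import numpy as np

batch_size, num_layers, num_heads, head_dim = 1, 12, 12, 64
empty_past = {}
for i in range(num_layers):
    # Past length 0: the model then attends only over the newly fed tokens.
    empty_past[f"past_key_values.{i}.key"] = np.zeros((batch_size, num_heads, 0, head_dim), dtype=np.float32)
    empty_past[f"past_key_values.{i}.value"] = np.zeros((batch_size, num_heads, 0, head_dim), dtype=np.float32)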
When I try to do this directly with onnxruntime, I also get an error. Here is what I do:
import onnxruntime as ort
from transformers import GPT2Tokenizer
import numpy as np

sess = ort.InferenceSession('output/model.onnx', providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("output/")
tokens = dict(tokenizer("My name is Julien and I like", return_tensors="np"))

# One random key and one random value tensor per layer: (batch, heads, past_seq_len, head_dim).
shape = (1, 12, len(tokens["input_ids"][0]), 64)
for i in range(12):
    tokens[f"past_key_values.{i}.key"] = np.random.uniform(0, 1, shape).astype(np.float32)
    tokens[f"past_key_values.{i}.value"] = np.random.uniform(0, 1, shape).astype(np.float32)

sess.run(None, tokens)
And I get the following error:
2022-12-06 16:42:17.603173515 [E:onnxruntime:, sequential_executor.cc:369 Execute] Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/jplu/anaconda3/envs/transformers/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:'/transformer/h.0/attn/Add' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:503 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 8 by 16
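As far as I can tell, the "8 by 16" matches my inputs: the tokenizer produces an attention_mask covering only the 8 new tokens, while with an 8-token random past the model attends over 8 + 8 = 16 positions. If that reading is right, the mask has to cover the past and current tokens together; a minimal sketch under that assumption, reusing the variables from the snippet above:

# Extend the mask so it covers past + current positions: (batch, past_len + cur_len).
past_len = shape[2]                        # the 8 random past positions fed above
cur_len = tokens["input_ids"].shape[1]     # the 8 new tokens
tokens["attention_mask"] = np.ones((1, past_len + cur_len), dtype=np.int64)
sess.run(None, tokens)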
Expected behavior
I expect generation to work properly, both through optimum and directly with onnxruntime. The final goal is to use the model through a Triton server.
I am certainly missing something, but the documentation is not clear on how to properly use seq2seq and causal-lm models with past key values, either directly with onnxruntime or through optimum.
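To make the question concrete, here is roughly what I have in mind for a greedy with-past loop written directly against onnxruntime. This is only a sketch on my side (the gpt2 geometry of 12 layers, 12 heads and head size 64 is hard-coded, and the input/output names are the ones from the validation log above). Is this the intended usage?

import numpy as np
import onnxruntime as ort
from transformers import GPT2Tokenizer

sess = ort.InferenceSession("output/model.onnx", providers=["CPUExecutionProvider"])
tokenizer = GPT2Tokenizer.from_pretrained("output/")
output_names = [o.name for o in sess.get_outputs()]

enc = tokenizer("My name is Julien and I like", return_tensors="np")
input_ids = enc["input_ids"].astype(np.int64)
attention_mask = enc["attention_mask"].astype(np.int64)
batch_size = input_ids.shape[0]

# First step: empty past, i.e. past sequence length 0.
past = {
    f"past_key_values.{i}.{kv}": np.zeros((batch_size, 12, 0, 64), dtype=np.float32)
    for i in range(12)
    for kv in ("key", "value")
}

generated = input_ids
for _ in range(20):  # generate at most 20 new tokens
    outputs = sess.run(None, {"input_ids": input_ids, "attention_mask": attention_mask, **past})
    out = dict(zip(output_names, outputs))
    # Greedy pick from the logits of the last position.
    next_token = out["logits"][:, -1, :].argmax(axis=-1)[:, None].astype(np.int64)
    generated = np.concatenate([generated, next_token], axis=-1)
    # Only the new token is fed next time; the cache carries the history,
    # and the mask has to keep covering past + current positions.
    input_ids = next_token
    past = {name.replace("present", "past_key_values"): value
            for name, value in out.items() if name.startswith("present")}
    attention_mask = np.concatenate([attention_mask, np.ones((batch_size, 1), dtype=np.int64)], axis=-1)

print(tokenizer.decode(generated[0]))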
Thanks a lot in advance for any advice you can provide!
Top GitHub Comments
cc @michaelbenayoun, we should add tests for the CLI.
Perfect! Waiting a week is perfectly OK. Out of curiosity, I will test with the main branch to see whether I can get it to work, and I will let you know in this thread if I encounter any issues.
Indeed, the generation is the hardest part to handle. On my side, I basically host all my ONNX models in a Triton server, and I have TritonModelForXXXX classes, similar to your ORTModelForXXXX, that handle the gRPC calls and can be used with pipelines. It does the job, but the downside is that it generates a lot of network calls. That's why I want to investigate using the Triton Python backend with optimum, to see if it works better. The ideal world, the dream, would indeed be a true end-to-end model that handles tokenization + inference for simple encoders and, in the case of decoder and encoder-decoder models, tokenization + inference + generation.