Difference in prediction output between a scikit-learn model and the converted ONNX model for the same data.
See the difference in output in the following image.
The dataset is for binary classification. The code to train the model is:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from xgboost import XGBClassifier

# One-hot encode the categorical columns; pass the remaining columns through.
categorical_pipeline_onehot = Pipeline([
    ('OneHot', OneHotEncoder(handle_unknown='ignore'))
])
preprocessor = ColumnTransformer(
    transformers=[('Onehot', categorical_pipeline_onehot, CATEGORICAL_ATTRIBUTES_ONEHOT)],
    remainder='passthrough')

xgb = XGBClassifier()
clf = Pipeline(steps=[('preprocessor', preprocessor), ('cls', xgb)])
clf.fit(df_train[predictors], df_train["IsSelected"])

# Register the XGBoost converter so skl2onnx can handle the XGBClassifier step.
from skl2onnx import update_registered_converter
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes
from onnxmltools.convert.xgboost.operator_converters.XGBoost import convert_xgboost

update_registered_converter(
    XGBClassifier, 'XGBoostXGBClassifier',
    calculate_linear_classifier_output_shapes, convert_xgboost,
    options={'nocl': [True, False], 'zipmap': [True, False]})
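The options registered here ('nocl', 'zipmap') only declare which values the converter accepts; the actual choice is made when calling convert_sklearn. As a minimal sketch, not from the original issue: disabling zipmap makes the probability output a plain 2D float tensor, which is easier to consume from the C++ API than the default sequence of maps (inputs is built with convert_dataframe_schema further down):

from skl2onnx import convert_sklearn

# Hypothetical conversion call with zipmap disabled: probabilities come
# back as a float tensor instead of a sequence of label->probability maps.
model_onnx = convert_sklearn(
    clf, 'pipeline_with_onehot', inputs,
    options={id(xgb): {'zipmap': False}})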
To convert this model to ONNX, the code is:
from skl2onnx.common.data_types import (
    Int64TensorType, FloatTensorType, StringTensorType)

def convert_dataframe_schema(df, drop=None):
    # Map every DataFrame column to a named (None, 1) ONNX tensor input.
    inputs = []
    for k, v in zip(df.columns, df.dtypes):
        if drop is not None and k in drop:
            continue
        if v == 'int64':
            t = Int64TensorType([None, 1])
        elif v == 'float64':
            # Note: FloatTensorType is float32, so float64 columns are
            # narrowed at inference time.
            t = FloatTensorType([None, 1])
        else:
            t = StringTensorType([None, 1])
        inputs.append((k, t))
    return inputs
from skl2onnx import convert_sklearn

inputs = convert_dataframe_schema(df_train[predictors])
try:
    model_onnx_wholed = convert_sklearn(clf, 'pipeline_with_onehot', inputs)
except Exception as e:
    print(e)
with open("XGB_onehot_wholedata.onnx", "wb") as f:
    f.write(model_onnx_wholed.SerializeToString())
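Each DataFrame column becomes a separate graph input (the log below reports 22 inputs), so at inference time the model expects a dictionary with one (n, 1) array per column, and float64 columns must be cast to float32 to match FloatTensorType. A minimal sketch with the onnxruntime Python API; the DataFrame name df_test is illustrative, not from the original issue:

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("XGB_onehot_wholedata.onnx")

# Build one (n, 1) input per column, matching convert_dataframe_schema.
feeds = {}
for col in df_test[predictors].columns:
    values = df_test[col].values.reshape(-1, 1)
    if values.dtype == np.float64:
        values = values.astype(np.float32)  # FloatTensorType expects float32
    elif values.dtype != np.int64:
        values = values.astype(str)         # StringTensorType
    feeds[col] = values

pred_label, pred_proba = sess.run(None, feeds)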
For prediction I am using the onnxruntime C++ API as well as the Python API. The prediction output generated by the ONNX model is all 0 for both C++ and Python, which does not match the output of the trained sklearn model.
The ONNX model adds a Cast operator for the features. The mismatch may be due to this, but I am not sure.
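One way to narrow this down is to compare the two models on the same rows: if the probabilities only diverge slightly, the float64-to-float32 Cast is harmless, whereas all-zero labels usually point at the inputs being fed incorrectly. A minimal sketch, reusing sess, feeds, and df_test from the hypothetical snippet above, with the default zipmap output:

# sklearn probabilities on the raw DataFrame.
skl_proba = clf.predict_proba(df_test[predictors])

# With zipmap enabled (the default), "probabilities" is a sequence of
# maps keyed by class label; flatten it into a 2D array first.
onnx_out = sess.run(["probabilities"], feeds)[0]
onnx_proba = np.array([[row[0], row[1]] for row in onnx_out])

np.testing.assert_allclose(skl_proba, onnx_proba, rtol=1e-3, atol=1e-4)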
The C++ onnxruntime API log is as follows:
2020-07-06 18:50:47.2890620 [W:onnxruntime:test, abi_session_options.cc:147 OrtApis::SetIntraOpNumThreads] Since openmp is enabled in this build, this API cannot be used to configure intra op num threads. Please use the openmp environment variables to control the number of threads.
Using Onnxruntime C++ API
2020-07-06 18:50:55.2296864 [I:onnxruntime:, inference_session.cc:174 onnxruntime::InferenceSession::ConstructorCommon] Creating and using per session threadpools since use_per_session_threads_ is true
2020-07-06 18:50:55.4219403 [I:onnxruntime:, inference_session.cc:830 onnxruntime::InferenceSession::Initialize] Initializing session.
2020-07-06 18:50:55.4220585 [I:onnxruntime:, inference_session.cc:848 onnxruntime::InferenceSession::Initialize] Adding default CPU execution provider.
2020-07-06 18:50:55.4224222 [I:onnxruntime:test, bfc_arena.cc:15 onnxruntime::BFCArena::BFCArena] Creating BFCArena for Cpu
2020-07-06 18:50:55.4224870 [V:onnxruntime:test, bfc_arena.cc:32 onnxruntime::BFCArena::BFCArena] Creating 21 bins of max chunk size 256 to 268435456
2020-07-06 18:50:55.4247204 [V:onnxruntime:, inference_session.cc:671 onnxruntime::InferenceSession::TransformGraph] Node placements
2020-07-06 18:50:55.4248273 [V:onnxruntime:, inference_session.cc:673 onnxruntime::InferenceSession::TransformGraph] All nodes have been placed on [CPUExecutionProvider].
2020-07-06 18:50:55.4249199 [I:onnxruntime:, session_state.cc:25 onnxruntime::SessionState::SetGraph] SaveMLValueNameIndexMapping
2020-07-06 18:50:55.4250567 [I:onnxruntime:, session_state.cc:70 onnxruntime::SessionState::SetGraph] Done saving OrtValue mappings.
2020-07-06 18:50:55.4266916 [I:onnxruntime:, session_state_initializer.cc:178 onnxruntime::SaveInitializedTensors] Saving initialized tensors.
2020-07-06 18:50:55.4268214 [I:onnxruntime:, session_state_initializer.cc:223 onnxruntime::SaveInitializedTensors] Done saving initialized tensors
2020-07-06 18:50:55.7311821 [I:onnxruntime:, inference_session.cc:919 onnxruntime::InferenceSession::Initialize] Session successfully initialized.
Number of inputs = 22
Number of outputs = 2
Output 0 : name=label
Output 1 : name=probabilities
2020-07-06 18:51:18.7355209 [I:onnxruntime:, sequential_executor.cc:145 onnxruntime::SequentialExecutor::Execute] Begin execution
2020-07-06 18:51:27.4858870 [I:onnxruntime:test, bfc_arena.cc:259 onnxruntime::BFCArena::AllocateRawInternal] Extending BFCArena for Cpu. bin_num:0 rounded_bytes:256
2020-07-06 18:51:27.4867688 [I:onnxruntime:test, bfc_arena.cc:143 onnxruntime::BFCArena::Extend] Extended allocation by 1048576 bytes.
2020-07-06 18:51:27.4871985 [I:onnxruntime:test, bfc_arena.cc:147 onnxruntime::BFCArena::Extend] Total allocated bytes: 1048832
2020-07-06 18:51:27.4872666 [I:onnxruntime:test, bfc_arena.cc:150 onnxruntime::BFCArena::Extend] Allocated memory at 000001D54AC25060 to 000001D54AD25060
Predicted Class:
0
0
0
0
0
0
0
0
0
0
0
0
Versions:
skl2onnx: 1.7.0
onnxruntime: 1.3.0