[torch-ort-infer] Aten fallback doesn't work
ATen op doesn't fall back to the native PyTorch runtime as expected.
Versions: torch 1.12.0, onnxruntime 1.12.0, torch-ort-infer 1.12.0
Reproduction steps:
import torch
from torch_ort import ORTInferenceModule


def test_numpy_T(input_shape):
    class NeuralNet(torch.nn.Module):
        def __init__(self):
            super(NeuralNet, self).__init__()

        def forward(self, input):
            # Tensor.T on an N-D tensor reverses all dimensions (aten::numpy_T)
            return input.T

    device = "cpu"
    ort_model = ORTInferenceModule(NeuralNet().to(device))

    def run_step(model, input):
        prediction = model(input)
        return prediction

    ort_input = torch.rand(input_shape, dtype=torch.float, device=device)
    # Fails while the InferenceSession is being created (see error log)
    ort_prediction = run_step(ort_model, ort_input)


if __name__ == "__main__":
    test_numpy_T([3, 2, 5])
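One possible workaround, not confirmed in this issue, is to avoid `Tensor.T` on N-D tensors and express the same dimension reversal with `permute`, on the assumption that the ONNX exporter maps `permute` to a standard `Transpose` node rather than an ATen fallback. This sketch only checks in plain PyTorch (no torch_ort) that both forms produce the same tensor; `NeuralNetPermute` is a hypothetical name.

```python
import torch


class NeuralNet(torch.nn.Module):
    def forward(self, input):
        # Original op from the report: reverses all dimensions.
        return input.T


class NeuralNetPermute(torch.nn.Module):
    def forward(self, input):
        # Same reversal written with permute; assumed (not verified here)
        # to export as a plain ONNX Transpose instead of an ATen node.
        dims = list(range(input.dim() - 1, -1, -1))
        return input.permute(*dims)


x = torch.rand(3, 2, 5)
out_t = NeuralNet()(x)
out_p = NeuralNetPermute()(x)
print(tuple(out_t.shape), tuple(out_p.shape))
```

The intended use is to pass `NeuralNetPermute()` to `ORTInferenceModule` in place of `NeuralNet()`; whether that avoids the failing ATen node with torch-ort-infer 1.12.0 has not been verified here.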
Error log
Traceback (most recent call last):
  File "unit_test_atenop.py", line 23, in <module>
    test_numpy_T([3, 2, 5])
  File "unit_test_atenop.py", line 20, in test_numpy_T
    ort_prediction = run_step(ort_model, ort_input)
  File "unit_test_atenop.py", line 16, in run_step
    prediction = model(input)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/_utils_infer.py", line 98, in _forward
    return ortinferencemodule._forward_call(*inputs, **kwargs)
  File "/ort_aten_fb/lib/python3.8/site-packages/torch_ort/ortinferencemodule/ortinferencemodule.py", line 107, in _forward_call
    self._inference_session = onnxruntime.InferenceSession(
  File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 347, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/ort_aten_fb/lib/python3.8/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 386, in _create_inference_session
    sess = C.InferenceSession(session_options, self.model_bytes, False, self.read_config_from_model)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (ATen_0) output arg (data) type inference failed.
Also tested with the symbolic shape inference call used by ORTModule (ref: symbolic_shape); it fails with Exception("Incomplete symbolic shape inference").
Top GitHub Comments
@askhade Yes, sure
@natke: Can you add this example to this repo, explaining how to set the type when type inference fails in ORT?