
ONNXRuntimeError after enabling fp16 mixed precision training


Hi folks,

I tested fp16 mixed precision training with an ORTModule-wrapped GPT2 model on a fine-tuning task. However, after enabling fp16, I encountered the error shown below.
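For reference, here is a minimal sketch of the kind of setup described above, using ORTModule with plain torch.cuda.amp instead of the Optimum trainer that appears in the traceback; the model name, data loader, and hyperparameters are placeholders rather than the exact test code:

    import torch
    from torch.cuda.amp import GradScaler, autocast
    from transformers import GPT2LMHeadModel
    from onnxruntime.training.ortmodule import ORTModule

    # Wrap the Hugging Face model so forward/backward passes run through ONNX Runtime.
    model = GPT2LMHeadModel.from_pretrained("gpt2").cuda()
    model = ORTModule(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scaler = GradScaler()

    # `dataloader` is an assumed placeholder yielding dicts with input_ids,
    # attention_mask and labels already on the GPU.
    for batch in dataloader:
        optimizer.zero_grad()
        with autocast():  # fp16 mixed precision forward pass
            loss = model(**batch).loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()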

Error Message

Traceback (most recent call last):
  File "test_onnxruntime_train.py", line 115, in test_ort_trainer
    train_result = trainer.train()
  File "/workspace/optimum/onnxruntime/trainer.py", line 498, in train
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 1984, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/trainer.py", line 2016, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/ortmodule.py", line 81, in _forward
    return self._torch_module.forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_torch_module_ort.py", line 32, in _forward
    return self._execution_manager(self.is_training()).forward(*inputs, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 265, in forward
    override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_fallback.py", line 194, in handle_exception
    raise exception
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_training_manager.py", line 85, in forward
    self._initialize_graph_builder(training=True)
  File "/usr/local/lib/python3.6/dist-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 420, in _initialize_graph_builder
    self._onnx_models.exported_model.SerializeToString(), grad_builder_config)
RuntimeError: /onnxruntime_src/orttraining/orttraining/python/orttraining_pybind_state.cc:707 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const pybind11::bytes&, const onnxruntime::training::OrtModuleGraphBuilderConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Type Error: Type parameter (T) of Optype (Where) bound to different types (tensor(float) and tensor(float16) in node (Where_183).

It seems that the exported ONNX graph is broken due to incompatible input types. I am wondering where the problem comes from. Any insight on that?
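One way to see exactly which inputs disagree is to load the dumped graph with the onnx Python package and print the element types feeding each Where node. This is only a sketch: the file name is a placeholder for whatever graph was saved, and the element types print as TensorProto enum integers (1 = FLOAT, 10 = FLOAT16):

    import onnx

    # Placeholder path: point this at the graph dumped from ORTModule.
    model = onnx.load("gpt2_fp16_torch_exported_training.onnx")
    inferred = onnx.shape_inference.infer_shapes(model)

    # Map every value name to its element type (inputs, outputs, intermediates, initializers).
    elem_types = {}
    for vi in list(inferred.graph.input) + list(inferred.graph.output) + list(inferred.graph.value_info):
        elem_types[vi.name] = vi.type.tensor_type.elem_type
    for init in inferred.graph.initializer:
        elem_types[init.name] = init.data_type

    # Print the input types of every Where node; Where_183 should show a float/float16 mix.
    for node in inferred.graph.node:
        if node.op_type == "Where":
            print(node.name, [(name, elem_types.get(name)) for name in node.input])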


System information

Docker image built with the Dockerfile-cu11 in onnxruntime-training-examples.

  • OS: Ubuntu 18.04
  • CUDA/cuDNN version: 11/8
  • onnxruntime-training: 1.9.0+cu111
  • torch: 1.9.0+cu111
  • torch-ort: 1.9.0
  • Python version: 3.6
  • GPU: A100

Additional Information

  • I actually have a working setup under the environment torch 1.8.1 + torch-ort 1.9.0 + onnxruntime-training 1.11.0.dev20220113001+cu102, so I wonder if the error comes from the versions in the Dockerfile being outdated. However, I can no longer find how to install onnxruntime-training 1.11.0.dev20220113001+cu102.
  • Here is the ONNX graph exported with DebugOptions (image attached), not sure if that could help; see the sketch after this list for how such a dump can be produced.
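For completeness, a rough sketch of how such a graph dump can be produced with ORTModule's debug options; the prefix below is arbitrary, and the exact flags may differ slightly between onnxruntime-training versions:

    from transformers import GPT2LMHeadModel
    from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

    model = GPT2LMHeadModel.from_pretrained("gpt2").cuda()

    # save_onnx=True writes the exported and optimized training graphs to disk,
    # using onnx_prefix in the file names, so the failing node can be inspected offline.
    debug_options = DebugOptions(log_level=LogLevel.VERBOSE, save_onnx=True, onnx_prefix="gpt2_fp16")
    model = ORTModule(model, debug_options)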

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 8

Top GitHub Comments

2 reactions
JingyaHuang commented, Apr 7, 2022

Hi @ytaous and @baijumeswani ,

Thanks a lot for the reply, super glad to know the root of the error!!

I adopted the workaround suggested with onnx==1.10.2 and onnxruntime-training==1.11.0+cu113, and it works well for the previous models that we have tested! Although the mixed-precision training issue remains, I think we will temporarily skip the benchmark for gpt2 until the next release with the fix integrated.

Thanks again for the help!

2 reactions
baijumeswani commented, Apr 6, 2022

@JingyaHuang I can confirm that that PR didn’t make it to ORT release 1.11.0. Please use the workaround for now until we have a more permanent solution in place.


