
'assert num_heads > 0' error with DistilBert

See original GitHub issue

I get the following error when I try to optimize DistilBERT:

AssertionError                            Traceback (most recent call last)
<timed eval> in <module>

/opt/conda/lib/python3.7/site-packages/transformer_deploy/convert.py in main(input_args)
    245             onnx_path=onnx_model_path,
    246             onnx_optim_fp16_path=onnx_optim_fp16_path,
--> 247             use_cuda=True,
    248         )
    249         onnx_model = create_model_for_provider(path=onnx_optim_fp16_path, provider_to_use="CUDAExecutionProvider")

/opt/conda/lib/python3.7/site-packages/transformer_deploy/backends/ort_utils.py in optimize_onnx(onnx_path, onnx_optim_fp16_path, use_cuda)
     72         num_heads=0,  # automatic detection don't work with opset 13
     73         hidden_size=0,  # automatic detection
---> 74         optimization_options=optimization_options,
     75     )
     76 

/opt/conda/lib/python3.7/site-packages/onnxruntime/transformers/optimizer.py in optimize_model(input, model_type, num_heads, hidden_size, optimization_options, opt_level, use_gpu, only_onnxruntime)
    289 
    290     if not only_onnxruntime:
--> 291         optimizer.optimize(optimization_options)
    292 
    293     # Remove the temporary model.

/opt/conda/lib/python3.7/site-packages/onnxruntime/transformers/onnx_model_bert.py in optimize(self, options, add_dynamic_axes)
    317             if options is not None:
    318                 self.attention_mask.set_mask_format(options.attention_mask_format)
--> 319             self.fuse_attention()
    320 
    321         self.fuse_shape()

/opt/conda/lib/python3.7/site-packages/onnxruntime/transformers/onnx_model_bert.py in fuse_attention(self)
     52 
     53     def fuse_attention(self):
---> 54         self.attention_fusion.apply()
     55 
     56     def fuse_gelu(self):

/opt/conda/lib/python3.7/site-packages/onnxruntime/transformers/fusion_base.py in apply(self)
     41                     raise Exception("Can not find node in any graphs")
     42                 self.this_graph_name = graph.name
---> 43                 self.fuse(node, input_name_to_nodes, output_name_to_node)
     44 
     45         op_list = [node.op_type for node in self.nodes_to_add]

/opt/conda/lib/python3.7/site-packages/onnxruntime/transformers/fusion_attention.py in fuse(self, normalize_node, input_name_to_nodes, output_name_to_node)
    444             new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v,
    445                                                   q_num_heads, self.hidden_size, root_input,
--> 446                                                   attention_last_node.output[0], add_qk_str)
    447             if new_node is None:
    448                 return

/opt/conda/lib/python3.7/site-packages/onnxruntime/transformers/fusion_attention.py in create_attention_node(self, mask_index, q_matmul, k_matmul, v_matmul, q_add, k_add, v_add, num_heads, hidden_size, input, output, add_qk_str)
    161             Union[NodeProto, None]: the node created or None if failed.
    162         """
--> 163         assert num_heads > 0
    164 
    165         if hidden_size > 0 and (hidden_size % num_heads) != 0:

AssertionError: 
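For context, the assertion is raised by onnxruntime's BERT attention-fusion pass, which needs the number of attention heads either passed in explicitly or auto-detected from the graph; with num_heads=0 the detection fails on this DistilBERT export and create_attention_node asserts. A minimal sketch of the failing call (the model path and the FusionOptions setup here are illustrative, not the exact values transformer_deploy uses):

    from onnxruntime.transformers import optimizer
    from onnxruntime.transformers.fusion_options import FusionOptions

    # Illustrative fusion options; transformer_deploy builds its own.
    optimization_options = FusionOptions("bert")

    # num_heads=0 and hidden_size=0 request automatic detection of both values.
    # On this DistilBERT graph (opset 13) the detection fails, num_heads stays 0,
    # and create_attention_node hits `assert num_heads > 0`.
    optimized_model = optimizer.optimize_model(
        input="model.onnx",  # hypothetical path to the exported DistilBERT ONNX model
        model_type="bert",
        use_gpu=True,
        opt_level=1,
        num_heads=0,    # automatic detection
        hidden_size=0,  # automatic detection
        optimization_options=optimization_options,
    )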

While trying to resolve the issue, I observed that it did not occur when the optimizer from onnxruntime-tools was used with opt_level 99 (instead of the one in onnxruntime.transformers). However, the code then threw exceptions due to some skip layer normalization issues.
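For reference, that alternative path looks roughly like this (a sketch, assuming the legacy onnxruntime-tools package and an illustrative model path; with opt_level=99 onnxruntime applies all of its own graph optimizations first, which sidesteps the num_heads assertion but, as noted above, then failed later around skip layer normalization):

    # Legacy package, later superseded by onnxruntime.transformers inside onnxruntime itself.
    from onnxruntime_tools import optimizer

    optimized_model = optimizer.optimize_model(
        "model.onnx",    # hypothetical path to the exported DistilBERT ONNX model
        model_type="bert",
        num_heads=0,     # still relying on automatic detection
        hidden_size=0,
        opt_level=99,    # let onnxruntime run all of its built-in graph optimizations
        use_gpu=True,
    )
    optimized_model.save_model_to_file("model_optimized.onnx")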

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
vishalsrao commented, Dec 12, 2021

Thank you, @pommedeterresautee. It is working now!

0 reactions
pommedeterresautee commented, Dec 12, 2021

@vishalsrao I have been able to reproduce the crash on my machine with ORT 1.9.0. A workaround is to set the number of heads to 12 (it will work on ORT 1.9.0).

    # Inside transformer_deploy/backends/ort_utils.py, optimize_onnx():
    from onnxruntime.transformers import optimizer
    from onnxruntime.transformers.onnx_model_bert import BertOnnxModel

    optimized_model: BertOnnxModel = optimizer.optimize_model(
        input=onnx_path,
        model_type="bert",
        use_gpu=use_cuda,
        opt_level=1,
        num_heads=12,   # replace 0 (auto-detection) with 12, DistilBERT-base's head count
        hidden_size=0,  # automatic detection
        optimization_options=optimization_options,
    )

It’s related to the bug linked above (which was reported against ORT version 1.10.0); there is an issue in the positional encoding of DistilBERT, and I have found other related issues. So, hopefully, once that bug is fixed in ORT, we can switch to 1.10 and rework the number-of-heads detection!

Let me know if the workaround works for you.
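If you'd rather not hard-code the value, the head count and hidden size can also be read from the Hugging Face config (a sketch, assuming the distilbert-base-uncased checkpoint; DistilBERT configs expose these as n_heads and dim):

    from transformers import AutoConfig

    # DistilBERT-base uses 12 attention heads and a hidden size of 768.
    config = AutoConfig.from_pretrained("distilbert-base-uncased")
    num_heads = config.n_heads  # 12
    hidden_size = config.dim    # 768

    # Pass these to optimize_model instead of the 0/0 automatic detection.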



Top Results From Across the Web

  • Version 1.10 introduces a bug making transformer graph ...
    Describe the bug When I use ORT 1.10, optimize_model function to optimize Transformers model crash (issues during operator fusion) Traceback ...

  • transformers.models.distilbert.modeling_distilbert
    Source code for transformers.models.distilbert.modeling_distilbert ... Dropout(p=config.attention_dropout) assert self.dim % self.n_heads == 0 self.q_lin ...

  • assert n_state % config.n_head == 0 error - Stack Overflow
    I think the error message is pretty clear: assert n_state % config.n_head == 0. Tracing it back through the code, we can see....

  • PyTorch-Transformers
    An open source machine learning framework that accelerates the path from research prototyping to production deployment.

  • runtimeerror: cuda error: cublas_status_not_initialized when ...
    Hi, I tried to add some other embeddings in your BertEmbedding source code and then load the pretrained weights 'bert-base-chinese'.
