Modify BERT/BERT-descendants to be TorchScript-able (not just traceable)
🚀 Feature request
Modify the BERT models (src/transformers/modeling_bert.py) to conform to TorchScript requirements, so they can be jit.script()-ed, not just jit.trace()-ed (as is currently the only supported option).
Note: I have a working version implementing this, which I would like to contribute. See below.
Motivation
A scriptable model would support variable-length input, offering significant speedups and simpler deployment (no need to export separate models for different input lengths).
In addition, it would avoid other potential pitfalls of tracing (e.g., input-dependent code paths that are not covered by the tracing example input).
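For concreteness, here is a minimal sketch of the contrast (model name and shapes are arbitrary; the commented-out jit.script() call is what this request would enable):

```python
import torch
from transformers import BertModel

# torchscript=True is the config flag transformers uses to make the model
# trace-friendly (e.g., cloning tied weights instead of sharing them).
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Tracing records one execution with a fixed-shape example input; the
# resulting graph is specialized to that shape and to whichever code
# paths the example happened to exercise.
example_ids = torch.randint(0, model.config.vocab_size, (1, 128))
traced = torch.jit.trace(model, example_ids)

# Scripting compiles the Python source itself, so a scriptable model
# would handle any sequence length with a single exported artifact:
#
#   scripted = torch.jit.script(model)  # fails on the current code base
#   scripted(torch.randint(0, model.config.vocab_size, (1, 37)))
```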
Related issues:
- https://github.com/huggingface/transformers/issues/2417
- https://github.com/huggingface/transformers/issues/1204
- possibly also https://github.com/huggingface/transformers/issues/1477 and https://github.com/huggingface/transformers/issues/902
Your contribution
I have a working PR that modifies all the models in src/transformers/modeling_bert.py and makes them TorchScript-able. I have not tested it on other models that use BERT components (e.g., ALBERT), but it should be possible to extend the capability to those as well.
However, making it ready for submission would require significant work: besides formatting, documentation, testing, etc., my current version changes the method signatures, and I would need to avoid that to maintain backward compatibility.
Before putting in that work, I'd like to make sure that such a PR is something you'd be interested in and would be willing to merge, assuming it meets the requirements.
Top GitHub Comments
My change is available at https://github.com/sbrody18/transformers/tree/scripting
Note that it is based off a commit from earlier this month: https://github.com/huggingface/transformers/compare/ef0e9d806c51059b07b98cb0279a20d3ba3cbc1d...sbrody18:scripting. Since then, changes have been made to the BertModel interface (adding a return_tuple argument and changing the return type of the forward method), which would require more effort to resolve.
I listed the principles I used in https://github.com/huggingface/transformers/issues/5067#issuecomment-644989375. The original components tended to return tuples of different sizes depending on arguments, which is problematic for TorchScript. When a component BertX required an interface change to be scriptable, I made a BertScriptableX version with the modifications and had the BertX component inherit from it, modifying only the output so it stays compatible with the original API.
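For illustration, a minimal sketch of that pattern, with hypothetical names and deliberately simplified attention math (not the actual PR code):

```python
from typing import Optional, Tuple

import torch
from torch import nn


class BertScriptableSelfAttention(nn.Module):
    """Scriptable variant: a fixed-arity output, since TorchScript cannot
    compile a forward() whose tuple size depends on runtime arguments."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        scores = torch.matmul(
            self.query(hidden_states), self.key(hidden_states).transpose(-1, -2)
        )
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = torch.softmax(scores, dim=-1)
        context = torch.matmul(probs, self.value(hidden_states))
        # Always two elements; the second is None unless attentions are requested.
        return context, probs if output_attentions else None


class BertSelfAttention(BertScriptableSelfAttention):
    """Backward-compatible wrapper: restores the original variable-size
    tuple API for callers that are not being scripted."""

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        context, probs = super().forward(hidden_states, attention_mask, output_attentions)
        return (context, probs) if output_attentions else (context,)
```

With this split, torch.jit.script(BertScriptableSelfAttention(64)) compiles, while existing callers of BertSelfAttention keep working unchanged.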
I made scriptable versions of BertModel and all the BertFor<Task> classes, except BertForMaskedLM (some complexities there were too much work for a proof of concept). I added a test to demonstrate the scripting capability.
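A rough sketch of what such a test might look like, reusing the BertScriptableSelfAttention sketch above (the PR's actual test exercises the real scriptable BERT classes):

```python
import torch


def test_scriptable_attention_handles_variable_length():
    module = BertScriptableSelfAttention(hidden_size=32)
    module.eval()
    scripted = torch.jit.script(module)

    # One exported artifact should accept inputs of different lengths.
    for seq_len in (7, 128):
        hidden = torch.randn(1, seq_len, 32)
        context, probs = scripted(hidden)
        assert context.shape == (1, seq_len, 32)
        assert probs is None  # attentions were not requested
```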
Note that my change disables the gradient_checkpoint path in the encoder. I think this can be resolved, but I didn’t have the time to work on it.
@sgugger Please see POC implementation in PR above.