Same training time for different values of sliding window in Longformer
System Info
- Transformers: 4.20.1
- Python: 3.8.12
- Pretrained model & tokenizer from HF: `allenai/longformer-base-4096`
The training time does not change for any value of the sliding window. For example, a sliding window of 2, 512 (the default), or 1024 takes the same training time. This seems like a bug to me. I need a very small local window span (a sliding window of at most 64 across 4096 tokens), and the model is simply unusable in this scenario due to the excessive training time.
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
Simply set, before training: `model.config.attention_window = [SLIDE_WIN_ATTN] * 12`
Expected behavior
I would expect training time to fall roughly quadratically for lower values of SLIDE_WIN_ATTN (say 64) compared to the default of 512. However, the training time for both cases is the same (around 24 hours per epoch). In fact, SLIDE_WIN_ATTN values from 2 to 1024 take roughly the same training time, which should not be the case.
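To make the expectation concrete, here is a back-of-the-envelope count of query-key score entries (a sketch only; it ignores edge effects, global attention tokens, and the per-chunk overhead of the real kernels, which is exactly what the discussion below turns on):

```python
def attention_entries(seq_len, window=None):
    """Count query-key score entries: full attention is quadratic in
    seq_len, sliding-window attention is linear in seq_len."""
    if window is None:  # full self-attention
        return seq_len * seq_len
    # each token attends to window/2 positions on either side plus
    # itself, ignoring edge effects for simplicity
    return seq_len * (window + 1)

seq_len = 4096
full = attention_entries(seq_len)
for w in (2, 64, 512, 1024):
    sliding = attention_entries(seq_len, w)
    print(f"window={w:4d}: {sliding:>10,d} entries "
          f"({full / sliding:.0f}x fewer than full attention)")
```

By this count, a window of 64 should do roughly 8x less score computation than the default 512, which is why identical epoch times look suspicious.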
Issue Analytics
- State:
- Created a year ago
- Comments: 7
Top GitHub Comments
Hi @allohvk, I know you are talking about the training time. However, even with just the `forward` method of the model, we already see that the effect of `window_size` (used for local attention), i.e. linear time instead of quadratic time, appears only for a large enough `window_size` (and therefore with long enough sequences). For small `window_size`, some overhead prevents this much-desired effect. From this observation, I am afraid that this holds for training too.

If you measure this line directly https://github.com/huggingface/transformers/blob/8a61fe023430115bb61ec328a29d35571f4fc2c4/src/transformers/models/longformer/modeling_longformer.py#L820 (without any other parts, and therefore no other overhead), you will see this linear/quadratic running time.