Irregular VRAM usage with gpt-neo inference with sequences longer than 250 tokens
Environment info
- `transformers` version: 4.5.1 / HEAD
- Platform: Linux / Colab Pro
- Python version: 3.7
- PyTorch version (GPU?): 1.8.1 (CUDA 11.0)
- Tensorflow version (GPU?):
- Using GPU in script?: Yes, NVIDIA P100
- Using distributed or parallel set-up in script?:
Who can help
Information
Model I am using (Bert, XLNet …): EleutherAI/gpt-neo-2.7B
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
The task I am working on is:
- an official GLUE/SQUaD task: (give the name)
- my own task or dataset: (give details below)
To reproduce
Steps to reproduce the behavior:
- Install transformers in a Colab Pro notebook
- Run this script to log peak memory usage for inference with increasing sequence length (a rough sketch of it is shown after these steps): https://gist.github.com/finetuneanon/7ce0ed5090a27a383abffbbbc0433a29
- Wait for it to crash with an OOM error in the attention matmul somewhere above sequence length 1850
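For reference, the gist boils down to a loop like the following (a rough sketch, not the exact script; the fp16 loading, dummy-input construction, and generation settings are assumptions):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B").half().cuda()
model.eval()

for seq_length in range(250, 2049):
    # Dummy prompt of the target length (the real script may use actual text).
    ids = torch.randint(0, tokenizer.vocab_size, (1, seq_length)).cuda()
    torch.cuda.reset_peak_memory_stats()
    print(seq_length, torch.cuda.memory_allocated())
    with torch.no_grad():
        out = model.generate(
            ids,
            max_length=seq_length + 1,  # generate a single token per length
            return_dict_in_generate=True,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id,
        )
    print("ok", torch.cuda.max_memory_allocated())
    del ids, out
    torch.cuda.empty_cache()
```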
Output:
1870 5436434432
ok 6535669248
1871 5436434432
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-2-f2aeed4489bd> in <module>()
21 return_dict_in_generate=True,
22 repetition_penalty=1.2,
---> 23 pad_token_id=tokenizer.eos_token_id
24 )
25 del ids
13 frames
/usr/local/lib/python3.7/dist-packages/transformers/models/gpt_neo/modeling_gpt_neo.py in _attn(self, query, key, value, causal_mask, masked_bias, attn_dropout, attention_mask, head_mask)
238 key = key.to(torch.float32)
239
--> 240 attn_weights = torch.matmul(query, key.transpose(-1, -2))
241 attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
242
RuntimeError: CUDA out of memory. Tried to allocate 4.59 GiB (GPU 0; 15.90 GiB total capacity; 9.75 GiB already allocated; 4.60 GiB free; 10.42 GiB reserved in total by PyTorch)
The full output can be found here: https://gist.github.com/finetuneanon/c7292ea676f57f5bb63803685d80bf5b
The output has the format:
sequence_length occupied_cuda_memory_before_inference
ok peak_occupied_cuda_memory_during_inference
Doing inference with real text has the same issue.
Expected behavior
I expected memory usage to increase steadily instead of jumping around wildly, but I am not sure whether this might actually be the correct behaviour. If it is correct, reliably doing inference on long sequences with 16 GB of VRAM seems to be impossible, although it sometimes works.
I have also plotted the peak memory allocation during inference:
The green line is peak memory allocation, the brown line is the amount of memory in use before running inference.
Top GitHub Comments
Hi @finetuneanon
Thanks for the detailed issue!
So what is happening here is that the way local attention is designed is a bit weird (not the implementation), in that it splits the `seq_length` dim into `(num_blocks, block_length)`, but `block_length` is actually dynamic. It's equal to `window_size` by default, which is 256. But when `seq_length` is not evenly divisible by `block_length`, it's adjusted as follows, such that `seq_length` becomes evenly divisible by `block_length`.
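The adjustment amounts to shrinking `block_length` until it evenly divides the sequence length, roughly like this (a sketch of the logic, not the exact helper in `modeling_gpt_neo.py`):

```python
def get_block_length_and_num_blocks(seq_length, window_size=256):
    # Shrink block_length from window_size until it evenly divides seq_length.
    # For a prime seq_length this degenerates to block_length == 1.
    block_length = window_size
    while seq_length % block_length != 0:
        block_length -= 1
    num_blocks = seq_length // block_length
    return block_length, num_blocks
```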
So the shape of `query` becomes `(batch, num_blocks, block_length, hidden_dim)`, and then the `keys` and `values` are padded and the `seq_length` dim is split such that their shape becomes `(batch, num_blocks, window_size + block_length, hidden_dim)`. Here's a simple function to get the shapes of `query` and `key` for a given `seq_length`.
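A sketch of such a function (the `hidden_dim=2560` default below assumes gpt-neo-2.7B, and the head dimension is folded into `hidden_dim` for simplicity):

```python
def get_shapes(seq_length, window_size=256, hidden_dim=2560, batch=1):
    # Same dynamic block_length as in local attention.
    block_length = window_size
    while seq_length % block_length != 0:
        block_length -= 1
    num_blocks = seq_length // block_length
    # query is split into blocks; key/value are padded with the local window.
    query_shape = (batch, num_blocks, block_length, hidden_dim)
    key_shape = (batch, num_blocks, window_size + block_length, hidden_dim)
    return query_shape, key_shape
```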
Let's print the shapes for a few lengths (the particular lengths below are just for illustration):
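```python
# A few illustrative lengths, including 1870/1871 from the log above.
for seq_length in (256, 257, 1870, 1871):
    query_shape, key_shape = get_shapes(seq_length)
    print(seq_length, query_shape, key_shape)
```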
which gives
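```
256 (1, 1, 256, 2560) (1, 1, 512, 2560)
257 (1, 257, 1, 2560) (1, 257, 257, 2560)
1870 (1, 10, 187, 2560) (1, 10, 443, 2560)
1871 (1, 1871, 1, 2560) (1, 1871, 257, 2560)
```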
As you can see, because of the dynamic `block_length`, the dimensions are very different for different `seq_length`, which explains the irregular VRAM usage. If you set `seq_length` to 1871 (a prime, so `block_length` drops all the way to 1), you get the shapes in the last line above, which corresponds to the OOM you posted.
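(For a sense of scale: a `key` tensor of shape `(1, 1871, 257, 2560)` cast to float32 for the attention matmul takes 1871 × 257 × 2560 × 4 bytes ≈ 4.59 GiB, the same size as the allocation that fails in the traceback above.)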
So I wouldn’t say this is an implementation issue, that’s how the local attention algorithm is designed in mesh-tf.
Great, I ran a small test and it seems to be working! (Sorry about the earlier comment.) Here's the script:
I will run a few tests with the actual model and will let you know. If it works, feel free to open a PR 😃