Memory accumulates when training in a loop
The problem is that allocated GPU memory accumulates with each run of the loop, which eventually results in a RuntimeError: CUDA out of memory error. You can see the wandb "GPU memory allocated" chart produced by the code below here: wandb
I had the same problem when using Trainer's built-in hyperparameter_search, which I assume also runs training in a loop. Similar issues from the past are: https://github.com/huggingface/transformers/issues/1742 https://github.com/huggingface/transformers/issues/1134 https://gitmemory.com/issue/huggingface/transformers/9929/770965726
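To make the accumulation visible outside of the wandb chart, a helper along these lines (illustrative only, not part of the reproduction script below; the name and placement are mine) can be called at the end of each loop iteration:

    import torch as th

    def report_cuda_memory(iteration):
        # Illustrative helper (not in the original script): print PyTorch's
        # allocator counters so growth across loop iterations becomes visible.
        allocated_mib = th.cuda.memory_allocated() / 2**20
        peak_mib = th.cuda.max_memory_allocated() / 2**20
        print(f'iteration {iteration}: allocated {allocated_mib:.0f} MiB, '
              f'peak {peak_mib:.0f} MiB')
        # Reset the peak so the next reading is per-iteration rather than global.
        th.cuda.reset_peak_memory_stats()

The allocated counter reports memory held by live tensors, while the peak counter is the high-water mark since the last reset, which roughly matches what the wandb chart shows.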
Environment info
- transformers version: 4.4.2
- Platform: Linux-4.15.0-128-generic-x86_64-with-glibc2.10
- Python version: 3.8.8
- PyTorch version (GPU?): 1.8.0 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: I don't explicitly use the GPU, but I assume the Trainer object does. See the code below
- Using distributed or parallel set-up in script?: No
Who can help
Library:
- trainer: @sgugger
Information
Model I am using (Bert, XLNet …): BertForSequenceClassification.from_pretrained('bert-base-cased')
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
The task I am working on is:
- an official GLUE/SQuAD task: (give the name)
- my own task or dataset: (give details below) I have my own dataset, but I've reproduced the issue with the Amazon polarity dataset from huggingface's datasets library
To reproduce
Steps to reproduce the behavior:
- Create a Trainer object inside a loop
- Run training inside the loop
This code reproduces the error.
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
    BertConfig,
)
from datasets import load_dataset
from torch.utils.data import Dataset
import torch as th
import wandb
import os


class AmazonDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.text = data['content']
        self.labels = data['label']
        self.max_len = max_len
        self.n_datapoints = len(self.labels)

    def __len__(self):
        return self.n_datapoints

    def __getitem__(self, idx):
        text = self.text[idx]
        assert type(text) is str
        inputs = self.tokenizer(
            text=text,
            text_pair=None,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )
        return {
            'input_ids': th.flatten(inputs['input_ids']).type(th.long),
            'token_type_ids': th.flatten(
                inputs['token_type_ids']).type(th.long),
            'attention_mask': th.flatten(
                inputs['attention_mask']).type(th.long),
            'labels': th.tensor(self.labels[idx], dtype=th.long)
        }


def model_init():
    return BertForSequenceClassification.from_pretrained(
        MODEL_NAME, return_dict=True
    )


if __name__ == '__main__':
    os.environ['WANDB_WATCH'] = 'all'
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    dataset = load_dataset('amazon_polarity')
    train = AmazonDataset(
        data=dataset['train'][:5000],
        tokenizer=tokenizer,
        max_len=300
    )
    test = AmazonDataset(
        data=dataset['test'][:500],
        tokenizer=tokenizer,
        max_len=300
    )
    MODEL_NAME = 'bert-base-cased'
    N_EPOCHS = 1
    warmup_steps = int(len(train) * N_EPOCHS)
    for i in range(10):
        training_args = TrainingArguments(
            output_dir='output',
            do_train=True,
            do_eval=True,
            evaluation_strategy='steps',
            learning_rate=2e-5,
            weight_decay=0.1,
            logging_steps=50,
            per_device_eval_batch_size=30,
            per_device_train_batch_size=15,
            seed=1,
            num_train_epochs=N_EPOCHS,
            disable_tqdm=True,
            report_to=['wandb'],
            load_best_model_at_end=False,
            lr_scheduler_type='linear',
            warmup_steps=warmup_steps
        )
        model_config = BertConfig(
            vocab_size=tokenizer.vocab_size,
            pretrained_model_name_or_path=MODEL_NAME,
            num_labels=2,
            return_dict=True
        )
        trainer = Trainer(
            args=training_args,
            train_dataset=train,
            eval_dataset=test,
            tokenizer=tokenizer,
            model_init=model_init
        )
        run = wandb.init(
            project='Bug',
            name=f'Bug{i}'
        )
        trainer.train()
        run.finish()
Expected behavior
The loop runs without memory accumulating on each iteration.
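A common mitigation to try between iterations is sketched below. It is my own sketch, not something from this thread, and it assumes stale references are part of the problem; it may not address the wandb-related behaviour described in the comments further down.

    import gc
    import torch as th

    # Sketch only (not part of the original script): explicit cleanup at the end
    # of each loop body, after run.finish(). Rebinding `trainer` drops the last
    # reference to the previous Trainer (and the model/optimizer it holds),
    # gc.collect() makes sure those objects are reclaimed, and empty_cache()
    # returns the now-unused cached blocks to the driver.
    trainer = None
    gc.collect()
    th.cuda.empty_cache()

Given the maintainer comment below, though, the accumulation here seems tied to the wandb integration rather than to stale Trainer references.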
Top GitHub Comments
cc @borisdayma so you are aware.
I think the memory problem comes from the wandb integration. I do not see the problem without it: memory resets to 0 at each new step of the loop and goes back to the same max value.
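If the integration is indeed the culprit, a quick experiment (my assumption about the callback's behaviour, worth checking against the WandbCallback source for your transformers version) is to stop it from watching the model's parameters and gradients, which the script currently requests with WANDB_WATCH='all':

    import os

    # Possible experiment (assumption, not a confirmed fix from this thread):
    # the script sets WANDB_WATCH='all', which asks the wandb callback to call
    # wandb.watch() on the model each run; setting it to 'false' before training
    # should skip that, so wandb does not retain parameter/gradient hooks.
    os.environ['WANDB_WATCH'] = 'false'

Alternatively, passing report_to=[] in TrainingArguments should disable the wandb callback entirely, which is the cleanest way to reproduce the "without it" comparison described above.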