
Wav2Vec2 CUDA memory usage doubled in v4.11.3 compared to v4.10.3 with the same batch size


Environment info

  • transformers version: 4.11.3
  • Platform: Linux-5.11.0-40-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.8.1+cu111 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes, 3090
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten, @anton-l

Information

When using Wav2Vec2, memory usage roughly doubles when going from Transformers v4.10.3 to v4.11.3. Whereas my 3090 (24 GB memory) could handle a batch size of ~32 in v4.10.3, in v4.11.3 this is reduced to ~10.

The problem arises when using:

  • my own modified scripts

The task I am working on is:

  • ASR

To reproduce

Steps to reproduce the behavior:

  1. Run the script below with v4.10.3 and v4.11.3 and watch the CUDA memory usage (one way to log peak usage is sketched right after this step)
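
For comparison, one way to log peak CUDA memory per run (a hedged sketch, not part of the original report; it assumes the repro script below is wrapped so the print runs after training finishes):

import torch

torch.cuda.reset_peak_memory_stats()
# ... run the repro script below under each transformers version ...
print(f"peak CUDA memory: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")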

Reproduce script (relatively minimal):

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, TrainingArguments
from transformers.trainer import Trainer
from torch.utils.data.dataset import Dataset
import numpy as np

class ProcessedDataset(Dataset):
    def __init__(self, processor):
        self.processor = processor

    def __getitem__(self, i):
        x = np.ones(16000 * 10) # 10 seconds
        y = "this is a random sentence"
        with self.processor.as_target_processor():
            batch = {"labels": self.processor(y).input_ids}
        batch["input_values"] = self.processor(x, sampling_rate=16000).input_values
        return batch

    def __len__(self):
        return 10000

class DataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        input_features = [{"input_values": feature["input_values"][0]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=True,
            max_length=None,
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=True,
                max_length=None,
                pad_to_multiple_of=None,
                return_tensors="pt",
            )
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch


proc = Wav2Vec2Processor.from_pretrained("wietsedv/wav2vec2-large-xlsr-53-dutch")
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-nl-voxpopuli",
    attention_dropout=0,
    hidden_dropout=0,
    feat_proj_dropout=0,
    mask_time_prob=0,
    layerdrop=0,
    activation_dropout=0,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=proc.tokenizer.pad_token_id,
    vocab_size=len(proc.tokenizer),
    ctc_zero_infinity=True
)
ds = ProcessedDataset(proc)
data_collator = DataCollator(processor=proc)
args = TrainingArguments(
    output_dir="/tmp/tmp_model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    do_eval=False,
    num_train_epochs=1,
    fp16=True,
    group_by_length=False,
    save_steps=-1,
    eval_steps=1024,
    logging_steps=1024,
    warmup_steps=128,
    save_total_limit=1,
    dataloader_num_workers=1,
    seed=11
)

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)
trainer.train()

Expected behavior

Upgrading Hugging Face Transformers from v4.10.3 to a later version should keep memory usage in the same ballpark.

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 12 (9 by maintainers)

Top GitHub Comments

2 reactions
patrickvonplaten commented, Nov 15, 2021

OK, I think I already found one problem. It seems the gradient_checkpointing PR refactor wasn’t 100% backward compatible.

@MarktHart - could you add

model.gradient_checkpointing_enable()

before this line:

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)

This should more or less solve the problem.
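
For context, a minimal sketch of where the suggested call slots into the repro script above (it reuses proc, args, ds, and data_collator from that script; dropping the now-deprecated gradient_checkpointing=True config argument is an extra assumption on my part, not part of the suggestion):

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-nl-voxpopuli",
    ctc_loss_reduction="mean",
    pad_token_id=proc.tokenizer.pad_token_id,
    vocab_size=len(proc.tokenizer),
    ctc_zero_infinity=True,
)
model.gradient_checkpointing_enable()  # explicit opt-in after the refactor

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)
trainer.train()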

1 reaction
voidful commented, Dec 17, 2021

> @voidful - can you provide a reproducible script here? 😃 Thanks a lot!

It turned out to be a length issue on my custom dataset; simply applying .filter solved the problem. Sorry for misleading.
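
For readers hitting the same thing, a hedged sketch (not voidful’s actual code; the dataset name and length cutoff are placeholders) of dropping overly long clips with datasets’ .filter before training:

from datasets import Audio, load_dataset

MAX_SECONDS = 20  # hypothetical cutoff; pick whatever fits your GPU
dataset = load_dataset("common_voice", "nl", split="train")  # placeholder dataset
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
# Decode each clip and keep only those shorter than the cutoff.
dataset = dataset.filter(lambda ex: len(ex["audio"]["array"]) <= MAX_SECONDS * 16000)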
