
torch.nn.DataParallel causes strange GPU memory overflow

See original GitHub issue

Thanks for your great work! When I test this model with the following code:

import torch
from reformer_pytorch import ReformerLM
from torch.nn import functional as F

model = ReformerLM(
    num_tokens=20000,
    dim=1024,
    depth=24,
    max_seq_len=1024,
    heads=16,
    lsh_dropout=0.1,
    emb_dim=1024,  # embedding factorization for further memory savings
    causal=True,  # auto-regressive or not
    bucket_size=64,  # average size of qk per bucket, 64 was recommended in paper
    n_hashes=8,  # 4 is permissible per author, 8 is the best but slower
    ff_chunks=200,  # number of chunks for feedforward layer, make higher if there are memory issues
    weight_tie=False,  # tie parameters of each layer for no memory per additional depth
    attn_chunks=8,  # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv=0,  # persistent learned memory key values, from all-attention paper
    twin_attention=False,  # both branches of the reversible network will be attention
    use_full_attn=True,  # use full self attention, for comparison
    full_attn_thres=128,  # use full attention if context length is less than set value
    use_scale_norm=False  # use scale norm from 'Transformers without tears' paper
).cuda()

model = torch.nn.DataParallel(model)  # replicate the model across all visible GPUs
model.train()
x = torch.randint(0, 20000, (8, 1024)).long().cuda()  # dummy input token ids
y = torch.randint(0, 20000, (8, 1024)).long().cuda()  # dummy target token ids
pred = model(x)  # logits of shape (batch, seq_len, num_tokens)
loss = F.cross_entropy(pred.transpose(1, 2), y, reduction='mean')
loss.backward()
import ipdb
ipdb.set_trace()  # pause here to inspect GPU memory

Without model = torch.nn.DataParallel(model), 7616 MB of memory is used. But after I add model = torch.nn.DataParallel(model), it causes an OOM, even though each of the 8 GPUs has 16 GB of memory. I think it may be a problem in revtorch?
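
A quick way to confirm where the memory goes is to print per-device allocation right after the backward pass. A minimal sketch using PyTorch's standard memory introspection APIs (report_gpu_memory is a hypothetical helper, not part of reformer_pytorch):

import torch

def report_gpu_memory(tag):
    # memory_allocated / max_memory_allocated are standard torch.cuda APIs
    for i in range(torch.cuda.device_count()):
        alloc = torch.cuda.memory_allocated(i) / 2**20
        peak = torch.cuda.max_memory_allocated(i) / 2**20
        print(f"{tag} | cuda:{i}: {alloc:.0f} MiB allocated, {peak:.0f} MiB peak")

report_gpu_memory("after backward")  # call right after loss.backward()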

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 13 (10 by maintainers)

Top GitHub Comments

5 reactions
zbloss commented, Feb 15, 2020

I have also been running into memory issues, and I plan to remove this nn.DataParallel until we can get it sorted out.

In the meantime, I am working on adding support for Microsoft’s new DeepSpeed library, which may take care of all of this for us.
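
For reference, wrapping a model with DeepSpeed typically looks like the sketch below. This is illustrative only, assuming a hypothetical ds_config.json config file; the eventual reformer_pytorch integration may differ:

import torch
import deepspeed
from torch.nn import functional as F
from reformer_pytorch import ReformerLM

model = ReformerLM(num_tokens=20000, dim=1024, depth=24, max_seq_len=1024)
# deepspeed.initialize returns an engine that takes over device placement,
# the optimizer step, and (depending on the config) fp16 and ZeRO partitioning
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config='ds_config.json',  # hypothetical config file: batch size, fp16, ZeRO stage, ...
)

x = torch.randint(0, 20000, (8, 1024)).long().to(model_engine.device)
y = torch.randint(0, 20000, (8, 1024)).long().to(model_engine.device)
pred = model_engine(x)
loss = F.cross_entropy(pred.transpose(1, 2), y, reduction='mean')
model_engine.backward(loss)  # the engine handles loss scaling and backprop
model_engine.step()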

1 reaction
Phirefly9 commented, Feb 25, 2020

I would recommend that DataParallel pretty much never be used for models and data of the size this network targets. Distributed training is almost always faster, even for smaller numbers of GPUs, in my experience.
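
A minimal DistributedDataParallel sketch of that suggestion, assuming one process per GPU launched with e.g. torchrun --nproc_per_node=8 train.py (which sets LOCAL_RANK); everything beyond the issue's own script is an assumption:

import os
import torch
import torch.distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from reformer_pytorch import ReformerLM

dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])  # set by the launcher
torch.cuda.set_device(local_rank)

model = ReformerLM(num_tokens=20000, dim=1024, depth=24, max_seq_len=1024).cuda(local_rank)
model = DDP(model, device_ids=[local_rank])
model.train()

# each process handles its own shard of the batch (1 of the 8 sequences here)
x = torch.randint(0, 20000, (1, 1024)).long().cuda(local_rank)
y = torch.randint(0, 20000, (1, 1024)).long().cuda(local_rank)
pred = model(x)
loss = F.cross_entropy(pred.transpose(1, 2), y, reduction='mean')
loss.backward()  # gradients are all-reduced across processes during backward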


Top Results From Across the Web

Uneven usage of GPU memory by DataParallel causes out-of ...
I am running my NN model using DataParallel on 3 GPUs. In one GPU, the memory usage rises to 12gb and as a...
Possible memory leak when using nn.DataParallel #23 - GitHub
I believe Torch's nn.DataParallel respawns the worker threads between every epoch, and even though this should destroy the contexts and release GPU memory, ......
Why torch.nn.DataParallel causes unexpected CUDA OOM?
Is torch.nn.DataParallel doing anything weird with the memory, so that it has less memory available than N × (single-GPU) memory? Or is torch.nn. ...
[P] Eliminate PyTorch's `CUDA error: out of memory ... - Reddit
It only changes the batch size used for gpu operations, in the end everything is still accumulated/concatinated properly.
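
The micro-batching idea that result describes, as a rough sketch (my own illustration, not the linked project's code; it reuses model, x, y, and F from the script above, and the Adam optimizer is a hypothetical stand-in):

optimizer = torch.optim.Adam(model.parameters())  # hypothetical optimizer
accum_steps = 4  # split the batch of 8 into 4 micro-batches of 2
optimizer.zero_grad()
for micro_x, micro_y in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    pred = model(micro_x)
    loss = F.cross_entropy(pred.transpose(1, 2), micro_y, reduction='mean')
    (loss / accum_steps).backward()  # scale so the summed gradient matches the full batch
optimizer.step()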
PyTorch 101, Part 4: Memory Management and Using Multiple ...
Emptying the CUDA cache; reduce the size of the tensor if you are getting OOM...
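
The cache-emptying tip from that guide, as a minimal sketch using standard PyTorch calls (pred and loss are the tensors from the script above):

import gc
import torch

del pred, loss            # drop references so the tensors can be garbage-collected
gc.collect()
torch.cuda.empty_cache()  # release cached blocks back to the CUDA driver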
