
CUDA out of memory with `max_batch_size=1` using unconditional image-to-image

See original GitHub issue

This follows the README usage instructions, except with max_batch_size=1, running on Windows:

import torch
from imagen_pytorch import Imagen, ImagenTrainer, SRUnet256, Unet

# unets for unconditional imagen

unet1 = Unet(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    use_linear_attn=True,
)

unet2 = SRUnet256(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
)

# imagen, which contains the unets above (base unet and super resoluting ones)

imagen = Imagen(
    condition_on_text=False,  # this must be set to False for unconditional Imagen
    unets=(unet1, unet2),
    image_sizes=(64, 128),
    timesteps=1000,
)

trainer = ImagenTrainer(imagen).cuda()

# now get a ton of images and feed it through the Imagen trainer

training_images = torch.randn(4, 3, 256, 256).cuda()

# train each unet in concert, or separately (recommended) to completion

for u in (1, 2):
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
    trainer.update(unet_number=u)

# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

images = trainer.sample(batch_size=16)  # (16, 3, 128, 128)
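As a side note (this is not part of the original script), a hedged way to confirm which unet's training step hits the memory ceiling, using standard PyTorch peak-memory counters around the trainer calls defined above, would be:

# Hedged diagnostic sketch: reset the peak-memory counter before each unet's
# step and print the high-water mark afterwards, so the offending unet is obvious.
for u in (1, 2):
    torch.cuda.reset_peak_memory_stats()
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
    trainer.update(unet_number=u)
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"unet {u}: peak allocated = {peak_gib:.2f} GiB")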

The OOM error occurs during the SRUnet step (I set a breakpoint and checked):

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 6.00 GiB total capacity; 4.26 GiB already allocated; 0 bytes free; 4.31 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\imagen_pytorch\trainer.py", line 508, in forward
    self.scale(loss, unet_number = unet_number).backward()
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\imagen_pytorch\trainer.py", line 98, in inner
    out = fn(model, *args, **kwargs)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sterg\Documents\GitHub\sparks-baird\xtal2png\scripts\imagen_pytorch_example.py", line 41, in <module>
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 268, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
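The error message itself suggests max_split_size_mb. As a hedged sketch (this mitigates allocator fragmentation rather than genuine over-allocation), the setting can be applied from Python, provided the environment variable is set before the first CUDA allocation; the value 128 below is an illustrative guess, not a recommendation:

import os

# Set the allocator option before torch initializes the CUDA caching allocator
# (i.e. before the first tensor is moved to the GPU).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # imported after setting the environment variable on purpose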

I’m using an NVIDIA GeForce RTX 2060:

  • GPU Architecture: Turing
  • RTX-OPS: 37T
  • Boost Clock: 1680 MHz
  • Frame Buffer: 6 GB GDDR6
  • Memory Speed: 14 Gbps
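As a quick sanity check of headroom on a 6 GB card (standard PyTorch calls, not part of the original report):

import torch

# Report free vs. total device memory before starting training.
props = torch.cuda.get_device_properties(0)
free_bytes, total_bytes = torch.cuda.mem_get_info(0)
print(f"{props.name}: {free_bytes / 1024**3:.2f} GiB free "
      f"of {total_bytes / 1024**3:.2f} GiB total")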

See also #12

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 8 (1 by maintainers)

Top GitHub Comments

1 reaction
lucidrains commented, Jul 14, 2022

should be better now that only one unet is loaded into memory at any given time

if it still OOMs, you should buy a better graphics card
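For reference, the comment above is about keeping only one unet resident on the GPU at a time. A minimal sketch of the corresponding training pattern, using only the trainer API already shown in the script above (the step count is a hypothetical placeholder), trains the base unet to completion before starting the super-resolution unet:

# Illustrative sketch: finish training unet 1 before touching unet 2, so at
# most one unet is being optimized at any given time.
num_steps = 100_000  # hypothetical step count

for u in (1, 2):
    for _ in range(num_steps):
        loss = trainer(training_images, unet_number=u, max_batch_size=1)
        trainer.update(unet_number=u)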

1 reaction
lupinetine commented, Jun 17, 2022

I’m hitting the same issue with 12 GB, and have also run into it on a 40 GB A100 that I used to verify. I have trained and inferred successfully on 8–16 GB on versions up until 0.7. I jumped from 0.3 to 0.7, so I’ll have to backtrack to find the last version that worked successfully on my equipment.
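If backtracking through versions is not practical, a hedged workaround is to shrink the second stage until it fits in memory, using only constructor parameters already shown in the script above; the specific numbers below are illustrative guesses, not recommended values:

# Hypothetical smaller configuration for memory-constrained GPUs.
unet2_small = SRUnet256(
    dim=16,                          # halve the width of the SR unet
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
)

imagen_small = Imagen(
    condition_on_text=False,
    unets=(unet1, unet2_small),
    image_sizes=(32, 64),            # lower-resolution cascade
    timesteps=1000,
)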


Top Results From Across the Web

  • How to avoid "CUDA out of memory" in PyTorch (Stack Overflow): torch.cuda.empty_cache() provides a good alternative for clearing occupied CUDA memory, and we can also manually ...
  • Resolving CUDA Being Out of Memory With Gradient Accumulation: implementing gradient accumulation and automatic mixed precision to solve the CUDA out of memory issue when training big deep learning models (see the sketch after this list).
  • Solving the "RuntimeError: CUDA Out of memory" error: when using multi-GPU systems, use the CUDA_VISIBLE_DEVICES environment variable to select which GPU to use.
  • Compute Sanitizer User Manual (NVIDIA Documentation Center): the tool can precisely detect and report out-of-bounds and misaligned memory accesses to global, local and shared memory in CUDA applications.
  • CUDA out of Memory after few epochs (PyTorch Forums): getting a CUDA out of memory error after 4 epochs while training a Wav2Vec2 Hugging Face model with PyTorch.
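The gradient accumulation and mixed precision suggestion above is generic PyTorch advice rather than anything imagen-pytorch-specific (the trainer's max_batch_size argument already chunks the forward pass); a minimal standalone sketch of the pattern, with a placeholder model and data, looks like this:

import torch
from torch import nn

# Placeholder model, optimizer, and data; swap in your own.
model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # effective batch = micro-batch size * accum_steps

for step in range(100):
    x = torch.randn(8, 128, device="cuda")
    y = torch.randint(0, 10, (8,), device="cuda")

    with torch.cuda.amp.autocast():
        # Scale the loss so accumulated gradients average over micro-batches.
        loss = nn.functional.cross_entropy(model(x), y) / accum_steps

    scaler.scale(loss).backward()  # gradients accumulate across micro-batches

    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)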
