
Resume training uses way more VRAM than fresh training, triggering OOM

Hey again!

I noticed that after resuming my training (v1.0.0) I was getting an OOM, even though it was the same checkpoint I had been training on just minutes earlier.

Here’s a minimal example that triggers the OOM. Run it once to generate the checkpoint, then run it again to hit the OOM when loading it. You may have to adjust it to your own VRAM limits (I’m using a 3090).

# Run once to generate checkpoint. VRAM sitting at a safe 21/24GB
# Run again to load from existing checkpoint causing OOM


import torch, os
from imagen_pytorch import BaseUnet64, SRUnet256, ImagenTrainer, ElucidatedImagen
from imagen_pytorch.data import Dataset

checkpoint = "./checkpoint.pt"
img_path = "../images/"
train_unet_num = 1
img_size = 64
batch_size = 128
max_batch_size = 16

unet1 = BaseUnet64(
    dim = 256,
)

unet2 = SRUnet256(
    dim = 128,
)

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    num_sample_steps = (19, 19),
    condition_on_text = False
)

print('Starting DataLoader')
data = Dataset(img_path, image_size = img_size)
dl = torch.utils.data.DataLoader(data,
                                 batch_size = batch_size,
                                 shuffle = True,
                                 num_workers = 8)

trainer = ImagenTrainer(
    imagen,
    train_dl = dl,
    use_ema = False,
    only_train_unet_number = train_unet_num
).cuda()


if os.path.exists(checkpoint):
    print('Loading existing checkpoint')
    loaded = trainer.load(checkpoint)
    # trainer.load returns the checkpoint dict; 'steps' tracks per-unet step counts
    last_step = loaded['steps'][train_unet_num - 1]


print('Training')
for i in range(2):
    loss = trainer.train_step(unet_number = train_unet_num, max_batch_size = max_batch_size)

print('Saving')
trainer.save(checkpoint)

If you see anything obviously wrong here, let me know. Thanks!
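One way I can narrow this down is to watch the allocator around the load call. A small diagnostic sketch using standard torch.cuda counters (the vram_report helper is just for illustration, not part of the script above):

def vram_report(tag):
    # standard PyTorch allocator counters, reported in GiB
    alloc = torch.cuda.memory_allocated() / 2**30
    peak = torch.cuda.max_memory_allocated() / 2**30
    print(f'{tag}: {alloc:.2f} GiB allocated, {peak:.2f} GiB peak')

vram_report('before load')
loaded = trainer.load(checkpoint)
vram_report('after load')  # a big jump here would point at the checkpoint restore

If the jump happens at load time, one guess (unverified) is that the saved optimizer state is being materialized on the GPU alongside the already-initialized trainer.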

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

1 reaction
trufty commented on Jul 19, 2022

Haha, you have the opposite problem.

I’ll assume it’s just my training script then, even though I mostly used the example code from the repo, especially considering you’re on the same minor version as me.

Thanks for helping with this. @lucidrains Feel free to close this if you want. I’ll just lower my batch for now to get it to fit.
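For what it’s worth, the script above already accumulates gradients in chunks, so peak activation memory should track max_batch_size rather than batch_size; shrinking it is the cheap knob here. A sketch of what I mean (same call as in the script, just with a smaller chunk):

# keep the effective batch at 128 but halve the per-forward chunk;
# this frees activation memory, not optimizer state
loss = trainer.train_step(unet_number = train_unet_num, max_batch_size = 8)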

1 reaction
trufty commented on Jul 19, 2022

If you were to restart training from scratch, without an existing checkpoint, do you see a drop in total memory usage too?

(19, 19) or (19, 35) was recommended on the LAION Discord, and the results seem fine at that step count for Elucidated. (I still need to try the SRUnet though.)
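A quick way to make that comparison concrete, sketched with standard torch.cuda peak counters (assuming it’s dropped in right around the train step in the script above):

torch.cuda.reset_peak_memory_stats()
loss = trainer.train_step(unet_number = train_unet_num, max_batch_size = max_batch_size)
print(f'peak during step: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB')

Running it once fresh and once resumed should show whether the extra memory comes from the load itself or from the first step afterwards.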
