
[Accelerate] Each Training step takes an absurd amount of time

See original GitHub issue

I’m training a VQ-VAE-2 on Kaggle’s GPU. When I start training on a full epoch, each iteration (i.e. each step/pass over a single batch) takes around 5 minutes. GPU utilization is negligible, yet the process does occupy all the VRAM, as CUDA should 🤷

This is my training function:

def train(loader, val_loader, scheduler):
    accelerator = Accelerator(fp16=False, cpu=args.cpu_run)
    device = accelerator.device

    #initializing the model
    model = VQVAE(in_channel=3, channel=128, n_res_block=args.res_blocks,
                  n_res_channel=args.res_channel,
                  embed_dim=args.embed_dim, n_embed=args.n_embed,
                  decay=args.decay).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    accelerator.print(summary(model, (args.batch_size, 3, args.size, args.size)))
    accelerator.print(model) #vanilla pytorch summary

    model, optimizer, loader, val_loader = accelerator.prepare(model, optimizer, loader, val_loader)
    loader, val_loader = tqdm(loader), tqdm(val_loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    latent_loss_beta_list = torch.linspace(0, latent_loss_weight, 20)

    sample_size = 20

    mse_sum = 0
    mse_n = 0
    val_mse_sum, val_mse_n, beta_index = 0, 0, 0 #init beta index

    with wandb.init(project=args.wandb_project_name, config=args.__dict__, save_code=True, name=args.run_name, magic=True): 
        for epoch in range(args.epoch):
            #Starting Epoch loops
            model.train()
            for i, (img, label) in enumerate(loader):
                img = img.to(device)

                out, latent_loss = model(img)
                recon_loss = criterion(out, img)
                latent_loss = latent_loss.mean()

                beta_index = epoch  #for consistency
                if beta_index >= 24:
                    beta_index = latent_loss_beta_list[-1]

                loss = recon_loss + latent_loss_beta_list[beta_index] * latent_loss

                model.zero_grad(set_to_none=True)

                accelerator.backward(loss) #added loss to backprop

                if scheduler is not None:
                    scheduler.step()

                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradclip) #grad clipping

                optimizer.step()

                mse_sum += recon_loss.item() * img.shape[0]
                mse_n += img.shape[0]

                lr = optimizer.param_groups[0]['lr']

                wandb.log({"epoch": epoch+1, "mse": recon_loss.item(), 
                            "latent_loss": latent_loss.item(), "avg_mse": (mse_sum / mse_n), 
                            "lr": lr})

                if i % 25 == 0:
                    accelerator.print({"epoch": epoch+1, "mse": recon_loss.item(),
                        "latent_loss": latent_loss.item(), "avg_mse": (mse_sum/ mse_n), 
                        "lr": lr})

                #Performing validation and logging out images
                if epoch % 2 == 0:   #i % 100 == 0
                    model.eval()

                    #--------------VALIDATION------------------
                    for i, (img, label) in enumerate(val_loader):
                        img.to(device)
                        model.to(device)

                        with torch.no_grad():
                            out, latent_loss = model(img)

                        val_recon_loss = criterion(out, img)
                        val_latent_loss = latent_loss.mean()
                        val_loss = recon_loss + latent_loss_beta_list[beta_index] * latent_loss

                        val_mse_sum += recon_loss.item() * img.shape[0]
                        val_mse_n += img.shape[0]

                    wandb.log({"epoch": epoch+1, "val_mse": val_recon_loss.item(), 
                            "val_latent_loss": val_latent_loss.item(), "val_avg_mse": (val_mse_sum/ val_mse_n), 
                            "lr": lr})

                    accelerator.print({"epoch": epoch+1, "val_mse": val_recon_loss.item(), 
                            "val_latent_loss": val_latent_loss.item(), "val_avg_mse": (val_mse_sum/ val_mse_n), 
                            "lr": lr})

                    model.train()

            #Saving model checkpoints and sample reconstructions every 10 epochs
            if epoch % 10 == 0:
                model.eval()
                
                sample = img[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f'/kaggle/working/samples/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png',
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                wandb.log({f"{epoch+1}_Samples" : [wandb.Image(img) for img in torch.cat( [sample, out], 0) ]})
                accelerator.save(model.state_dict(), f'checkpoint/vqvae_{str(epoch + 1).zfill(3)}.pt')
                model.train()

            print(f'\n---EPOCH {epoch} COMPLETED---\n')

I don’t see how the tensor shapes are changing at all, so things should run relatively smoothly.

Validation on around 15,000 images is relatively fine, I suppose:

 98%|████████████████████████████████████████▎| 116/118 [07:37<00:04,  2.48s/it]

Though I’d expected it to be much, much faster.

But training:

2%|▌                                 | 31/1783 [2:37:49<145:33:55, 299.11s/it]
{'epoch': 1, 'val_mse': 0.07451071590185165, 'val_latent_loss': 0.03670966625213623, 'val_avg_mse': 0.13713628306023537, 'lr': 0.0001}

is pretty slow. Any ideas why computation seems to be offloaded to the CPU while the graph is clearly being built on the GPU?
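(Side note for anyone hitting the same wall: one way to narrow down where those ~300 s per iteration actually go is to time each stage of a single step with explicit CUDA synchronization, and to confirm which device the prepared model and batch ended up on. A rough sketch, assuming the same model/optimizer/criterion/accelerator objects as in the loop above; the helper name is made up:)

import time
import torch

def time_one_step(accelerator, model, optimizer, criterion, img):
    # Synchronize around each stage so the timings reflect actual GPU work,
    # not just asynchronous kernel launches.
    def sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    print("model device:", next(model.parameters()).device)
    print("batch device:", img.device)

    sync(); t0 = time.perf_counter()
    out, latent_loss = model(img)
    sync(); t1 = time.perf_counter()

    loss = criterion(out, img) + latent_loss.mean()
    accelerator.backward(loss)
    sync(); t2 = time.perf_counter()

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    sync(); t3 = time.perf_counter()

    print(f"forward {t1 - t0:.2f}s | backward {t2 - t1:.2f}s | optimizer step {t3 - t2:.2f}s")

If forward + backward + step account for only a small fraction of the ~300 s, the remaining time is being spent elsewhere in the loop body (data loading, W&B logging, or the validation block, which as written sits inside the batch loop and therefore runs for every training batch on even-numbered epochs) rather than in Accelerate itself.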

Also,

parser.add_argument('--cpu-run', type=bool, default=False)

my --cpu-run arg is set to False by default, so at least I am not overriding GPU computation 🤷‍♂️
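One caveat with that flag, independent of the slowdown: type=bool in argparse does not parse booleans the way it looks like it should. The raw command-line string is passed through bool(), and any non-empty string, including "False", is truthy, so passing --cpu-run False would actually set cpu_run=True. The default of False is only safe as long as the flag is never passed explicitly. A sketch of the usual safer pattern:

# `type=bool` applies bool() to the raw string, so `--cpu-run False` becomes True.
# A store_true flag is unambiguous: absent -> False, present -> True.
parser.add_argument('--cpu-run', action='store_true',
                    help='force the run onto CPU instead of CUDA')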

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 17 (7 by maintainers)

Top GitHub Comments

1 reaction
neel04 commented on Feb 24, 2022
Accelerator State: Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda
Use FP16 precision: False

I am going to revamp my entire training loop: make the train function train for only a single epoch, use an external function with a loop to drive it, and see whether that improves anything - because apart from a bug in Accelerate, the only thing I can think of is my loop 🤔
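For reference, the restructuring described here, a train function that handles exactly one epoch driven by an outer loop that owns the Accelerator and optimizer, might look roughly like the sketch below. The names train_one_epoch and run_training are illustrative, not from the actual repo; the same imports as the original script (accelerate.Accelerator, torch.optim as optim, torch.nn as nn) and an already-constructed VQVAE model are assumed, and the beta schedule, logging and sampling are stripped out for brevity.

def train_one_epoch(accelerator, model, optimizer, criterion, loader):
    # Plain training pass over one epoch; validation, sampling and checkpointing
    # live in the outer loop instead of inside the batch loop.
    model.train()
    for img, _ in loader:   # batches are already on device after accelerator.prepare
        out, latent_loss = model(img)
        loss = criterion(out, img) + latent_loss.mean()

        optimizer.zero_grad(set_to_none=True)
        accelerator.backward(loss)
        optimizer.step()


def run_training(model, loader, val_loader, args):
    accelerator = Accelerator(cpu=args.cpu_run)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()

    model, optimizer, loader, val_loader = accelerator.prepare(
        model, optimizer, loader, val_loader)

    for epoch in range(args.epoch):
        train_one_epoch(accelerator, model, optimizer, criterion, loader)
        # per-epoch validation, image logging and checkpoint saving would go here,
        # outside the per-batch loop

Keeping validation and image logging out of the inner batch loop means each training step only pays for forward, backward and the optimizer update.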

0 reactions
github-actions[bot] commented on May 24, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

