[Accelerate] Each Training step takes an absurd amount of time
Training a VQ-VAE-2 on Kaggle’s GPU: when I start training on a full epoch, each iteration (i.e. each step/pass over a single batch) takes around 5 minutes on the GPU. GPU utilization is negligible, yet it does occupy all the VRAM, as CUDA should 🤷
This is my training function:-
# Presumed imports for this snippet (not shown in the original post);
# VQVAE and args are defined elsewhere in the script.
import torch
from torch import nn, optim
from torchvision import utils
from tqdm import tqdm
from accelerate import Accelerator
from torchinfo import summary  # assumption: whichever summary helper the script actually uses
import wandb


def train(loader, val_loader, scheduler):
    accelerator = Accelerator(fp16=False, cpu=args.cpu_run)
    device = accelerator.device

    # Initializing the model
    model = VQVAE(in_channel=3, channel=128, n_res_block=args.res_blocks,
                  n_res_channel=args.res_channel,
                  embed_dim=args.embed_dim, n_embed=args.n_embed,
                  decay=args.decay).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    accelerator.print(summary(model, (args.batch_size, 3, args.size, args.size)))
    accelerator.print(model)  # vanilla PyTorch summary

    model, optimizer, loader, val_loader = accelerator.prepare(model, optimizer, loader, val_loader)
    loader, val_loader = tqdm(loader), tqdm(val_loader)

    criterion = nn.MSELoss()
    latent_loss_weight = 0.25
    latent_loss_beta_list = torch.linspace(0, latent_loss_weight, 20)  # linear warmup of the latent-loss weight
    sample_size = 20
    mse_sum = 0
    mse_n = 0
    val_mse_sum, val_mse_n, beta_index = 0, 0, 0  # init beta index

    with wandb.init(project=args.wandb_project_name, config=args.__dict__, save_code=True,
                    name=args.run_name, magic=True):
        for epoch in range(args.epoch):
            # Starting epoch loop
            model.train()
            for i, (img, label) in enumerate(loader):
                img = img.to(device)
                out, latent_loss = model(img)
                recon_loss = criterion(out, img)
                latent_loss = latent_loss.mean()

                beta_index = epoch  # for consistency
                if beta_index >= len(latent_loss_beta_list):
                    beta_index = len(latent_loss_beta_list) - 1  # clamp to the final warmup value
                loss = recon_loss + latent_loss_beta_list[beta_index] * latent_loss

                model.zero_grad(set_to_none=True)
                accelerator.backward(loss)  # backprop through Accelerate
                if scheduler is not None:
                    scheduler.step()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradclip)  # grad clipping
                optimizer.step()

                mse_sum += recon_loss.item() * img.shape[0]
                mse_n += img.shape[0]
                lr = optimizer.param_groups[0]['lr']

                wandb.log({"epoch": epoch + 1, "mse": recon_loss.item(),
                           "latent_loss": latent_loss.item(), "avg_mse": (mse_sum / mse_n),
                           "lr": lr})
                if i % 25 == 0:
                    accelerator.print({"epoch": epoch + 1, "mse": recon_loss.item(),
                                       "latent_loss": latent_loss.item(), "avg_mse": (mse_sum / mse_n),
                                       "lr": lr})

            # Performing validation and logging out images
            if epoch % 2 == 0:  # was: i % 100 == 0
                model.eval()
                # -------------- VALIDATION ------------------
                for i, (img, label) in enumerate(val_loader):
                    img = img.to(device)
                    model.to(device)
                    with torch.no_grad():
                        out, latent_loss = model(img)
                        val_recon_loss = criterion(out, img)
                        val_latent_loss = latent_loss.mean()
                        val_loss = val_recon_loss + latent_loss_beta_list[beta_index] * val_latent_loss
                        val_mse_sum += val_recon_loss.item() * img.shape[0]
                        val_mse_n += img.shape[0]

                    wandb.log({"epoch": epoch + 1, "val_mse": val_recon_loss.item(),
                               "val_latent_loss": val_latent_loss.item(), "val_avg_mse": (val_mse_sum / val_mse_n),
                               "lr": lr})
                    accelerator.print({"epoch": epoch + 1, "val_mse": val_recon_loss.item(),
                                       "val_latent_loss": val_latent_loss.item(), "val_avg_mse": (val_mse_sum / val_mse_n),
                                       "lr": lr})
                model.train()

            # Saving sample reconstructions and a model checkpoint every 10 epochs
            if epoch % 10 == 0:
                model.eval()
                sample = img[:sample_size]
                with torch.no_grad():
                    out, _ = model(sample)
                utils.save_image(
                    torch.cat([sample, out], 0),
                    f'/kaggle/working/samples/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png',
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )
                wandb.log({f"{epoch+1}_Samples": [wandb.Image(img) for img in torch.cat([sample, out], 0)]})
                accelerator.save(model.state_dict(), f'checkpoint/vqvae_{str(epoch + 1).zfill(3)}.pt')
                model.train()

            print(f'\n--- EPOCH {epoch} COMPLETED ---\n')
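A quick sanity check worth running with this setup (not in the original post, and it assumes accelerator, model and loader are the prepared objects from train() above): after accelerator.prepare(...), both the model parameters and the batches coming out of the loader should report a CUDA device, which would rule out anything silently living on the CPU.

# Minimal sketch: confirm device placement after accelerator.prepare().
# Assumes `accelerator`, `model` and `loader` are the prepared objects from train() above.
accelerator.print("accelerator.device:", accelerator.device)
accelerator.print("model parameters on:", next(model.parameters()).device)

batch_img, _ = next(iter(loader))
accelerator.print("batch tensors on:", batch_img.device)  # expect cuda:0, not cpu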
I don’t see how the tensor shapes are changing at all, so things should run relatively smoothly.
Validation on around 15,000 images is relatively fine, I suppose:
98%|████████████████████████████████████████▎| 116/118 [07:37<00:04, 2.48s/it]
Though I’d expected it to be much, much faster.
But training,
2%|▌ | 31/1783 [2:37:49<145:33:55, 299.11s/it]
{'epoch': 1, 'val_mse': 0.07451071590185165, 'val_latent_loss': 0.03670966625213623, 'val_avg_mse': 0.13713628306023537, 'lr': 0.0001}
is pretty slow. Any ideas why computations are being offloaded to the CPU while the graph is clearly being built on the GPU?
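One thing that stands out from those numbers: the validation bar above shows 118 batches at roughly 2.48 s each, i.e. about 290 s, which is almost exactly the ~299 s being reported per training step. So it is worth double-checking that the validation/sampling blocks are not accidentally nested inside the per-batch training loop (the indentation is easy to get wrong there), and also how much of each step goes to waiting on the DataLoader versus GPU work. A rough timing sketch, assuming the prepared model, loader, criterion, optimizer and accelerator from train() are in scope:

# Rough per-step timing sketch (not part of the original code): separates
# data-loading time from forward/backward/optimizer time for a few batches.
import time
import torch

model.train()
data_time, compute_time, n_batches = 0.0, 0.0, 10
t0 = time.perf_counter()
for i, (img, label) in enumerate(loader):
    t1 = time.perf_counter()
    data_time += t1 - t0                      # time spent waiting on the DataLoader

    out, latent_loss = model(img)
    loss = criterion(out, img) + 0.25 * latent_loss.mean()
    model.zero_grad(set_to_none=True)
    accelerator.backward(loss)
    optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()              # make sure GPU work is counted here
    t0 = time.perf_counter()
    compute_time += t0 - t1                   # forward/backward/optimizer time

    if i + 1 == n_batches:
        break

accelerator.print(f"avg data time {data_time / n_batches:.2f}s | "
                  f"avg compute time {compute_time / n_batches:.2f}s per batch")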
Also,
parser.add_argument('--cpu-run', type=bool, default=False)
my CPU run arg is set to False by default, so at least I am not overriding GPU computation 🤷‍♂️
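A small aside on that flag, since it is easy to trip over: argparse's type=bool does not parse strings the way one might hope, because bool('False') is True, so any explicit value passed on the command line turns the flag on. The default of False is fine as long as the flag is never passed. A tiny illustration with a throwaway parser (the --cpu-run-bool name is made up purely for the comparison):

import argparse

parser = argparse.ArgumentParser()
# The pattern from the post: only safe as long as the flag is never passed explicitly
parser.add_argument('--cpu-run-bool', type=bool, default=False)
# The usual on/off switch: flag absent -> False, flag present -> True
parser.add_argument('--cpu-run', action='store_true')

args = parser.parse_args(['--cpu-run-bool', 'False'])
print(args.cpu_run_bool)  # True, because bool('False') is truthy
print(args.cpu_run)       # False, flag not passed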
Top GitHub Comments
I am currently going to revamp my entire training loop by making the train function train for only a single epoch, using another external function with a loop to interface with and run train, and seeing whether that improves anything, because apart from a bug in Accelerate the only thing I can think of is my loop 🤔
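For reference, a minimal sketch of that restructuring, with hypothetical names (train_one_epoch and fit are not from the issue): the inner function does exactly one pass over the loader, and the outer function owns the epoch loop, the latent-loss warmup, and anything like validation or checkpointing.

def train_one_epoch(accelerator, model, loader, optimizer, criterion, beta):
    # One full pass over the training loader; returns the average reconstruction MSE.
    model.train()
    mse_sum, n = 0.0, 0
    for img, _ in loader:
        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        loss = recon_loss + beta * latent_loss.mean()
        optimizer.zero_grad(set_to_none=True)
        accelerator.backward(loss)
        optimizer.step()
        mse_sum += recon_loss.item() * img.shape[0]
        n += img.shape[0]
    return mse_sum / n


def fit(accelerator, model, loader, val_loader, optimizer, criterion, epochs, betas):
    # Outer driver: epoch loop, warmup schedule, and (not shown) validation/checkpointing.
    for epoch in range(epochs):
        beta = betas[min(epoch, len(betas) - 1)]  # clamp the warmup index
        avg_mse = train_one_epoch(accelerator, model, loader, optimizer, criterion, beta)
        accelerator.print({"epoch": epoch + 1, "avg_mse": avg_mse, "beta": float(beta)})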
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.