Pure PyTorch is faster than Lightning on a small CPU-only toy example
🐛 Bug
Recently I found that Lightning runs much slower than simple PyTorch code.
Code using Lightning:
```python
import os
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pl_bolts.datasets import DummyDataset
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR

train = DummyDataset((1, 28, 28), (1,), num_samples=100000)
train = DataLoader(train, batch_size=32)
val = DummyDataset((1, 28, 28), (1,))
val = DataLoader(val, batch_size=32)
test = DummyDataset((1, 28, 28), (1,))
test = DataLoader(test, batch_size=32)


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss
        # --------------------------

    def validation_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)
        # --------------------------

    def test_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss)
        # --------------------------

    def configure_optimizers(self):
        learning_rate = 1e-3
        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        return optimizer


if __name__ == '__main__':
    ae = LitAutoEncoder()
    # Initialize a trainer
    trainer = pl.Trainer(gpus=None, max_epochs=5,
                         progress_bar_refresh_rate=1000, log_every_n_steps=1000,
                         # profiler='simple'
                         )
    # Train the model ⚡
    # trainer.fit(ae, train, val)
    trainer.fit(ae, train)
```
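Lightning's built-in profiler (the `profiler='simple'` argument left commented out above) is one way to see where the extra per-step time goes. A minimal sketch of enabling it for a single epoch; it prints a per-hook timing summary when `fit` finishes:

```python
# Minimal sketch: rerun one epoch with Lightning's simple profiler enabled
# to break down where the per-step time is spent.
profiled_trainer = pl.Trainer(
    gpus=None,
    max_epochs=1,
    progress_bar_refresh_rate=1000,
    log_every_n_steps=1000,
    profiler='simple',  # report timing of the training hooks at the end of fit()
)
profiled_trainer.fit(LitAutoEncoder(), train)
```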
Code using just PyTorch:
```python
import os
import math
import time
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pl_bolts.datasets import DummyDataset
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR

train = DummyDataset((1, 28, 28), (1,), num_samples=100000)
train = DataLoader(train, batch_size=32)
val = DummyDataset((1, 28, 28), (1,))
val = DataLoader(val, batch_size=32)
test = DummyDataset((1, 28, 28), (1,))
test = DataLoader(test, batch_size=32)


class PlainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, batch):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss
        # --------------------------


if __name__ == '__main__':
    model = PlainModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    for epoch_num in range(30):
        batch_count = 0
        start_time = time.time()
        for batch in train:
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()
            batch_count += 1
        end_time = time.time()
        print('Epoch {0} speed: {1} in {2} time'.format(epoch_num, batch_count / (end_time - start_time),
                                                        end_time - start_time))
```
On my Mac (CPU only), the Lightning code runs at about 450 it/s, while the vanilla PyTorch loop runs at about 650 it/s, so vanilla PyTorch is roughly 1.44× faster than Lightning (650 / 450 ≈ 1.44).
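For an apples-to-apples number, the Lightning run can be timed with the same wall-clock measurement the plain loop uses; the it/s figure is simply batches divided by elapsed seconds. A minimal sketch, assuming the `ae` model, the `train` loader, and the `pl` import from the Lightning script above:

```python
# Sketch: time one Lightning epoch the same way the plain PyTorch loop is timed.
# Assumes `ae` (LitAutoEncoder), `train` (DataLoader) and `pl` from the Lightning script.
import time

start_time = time.time()
pl.Trainer(gpus=None, max_epochs=1,
           progress_bar_refresh_rate=1000, log_every_n_steps=1000).fit(ae, train)
elapsed = time.time() - start_time
print('Lightning speed: {0} it/s in {1} s'.format(len(train) / elapsed, elapsed))
```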
To Reproduce
Use the above code.
Expected behavior
Lightning runs at nearly the same speed as the vanilla PyTorch code.
Environment
- PyTorch version: 1.5.0
- OS: macOS
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): n/a
- Python version: 3.8.3
- CUDA/cuDNN version: n/a
- GPU models and configuration: n/a
- Any other relevant information: n/a
- PyTorch Lightning version: tried both 1.1.8 and 1.2.1
Additional context
n/a
Lightning includes "quite a bit of magic" that adds a fixed overhead on top of PyTorch. As @SeanNaren points out, this overhead is fixed and the scaling behaviour should be very similar, so for non-trivial networks it should not matter as much. Incidentally, PyTorch has its own performance issue with `nn.Module`, see https://github.com/pytorch/pytorch/pull/50431. The symptoms are somewhat similar (though the scaling behaviour is arguably worse, since larger networks have more modules) and we might learn from their analysis.
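To make the "fixed overhead" point concrete, here is a back-of-the-envelope sketch; the per-step times below are illustrative assumptions, not measurements. A constant per-step framework cost shrinks relative to the total as the model's per-step compute grows, so the throughput gap narrows for larger networks:

```python
# Illustrative only: hypothetical fixed framework overhead per step vs. growing per-step compute.
overhead_s = 0.0007  # assumed fixed Lightning overhead per training step (seconds)

for compute_s in (0.0015, 0.015, 0.15):  # assumed per-step compute: tiny / medium / large model
    plain_its = 1.0 / compute_s
    lightning_its = 1.0 / (compute_s + overhead_s)
    print('compute={:.4f}s  plain={:7.1f} it/s  lightning={:7.1f} it/s  ratio={:.2f}x'
          .format(compute_s, plain_its, lightning_its, plain_its / lightning_its))
```

With these assumed numbers the gap is about 1.47× for the tiny model but under 1.01× for the largest one, matching the intuition that the overhead mainly matters for toy-sized workloads.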