Pure PyTorch vs Lightning: pure PyTorch is faster on CPU in a small toy example


🐛 Bug

Recently I found that Lightning runs much slower than simple PyTorch code.

Code using Lightning:

import os
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pl_bolts.datasets import DummyDataset
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR

train = DummyDataset((1, 28, 28), (1,), num_samples=100000)
train = DataLoader(train, batch_size=32)

val = DummyDataset((1, 28, 28), (1,))
val = DataLoader(val, batch_size=32)

test = DummyDataset((1, 28, 28), (1,))
test = DataLoader(test, batch_size=32)


class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss
        # --------------------------

    def validation_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)
        # --------------------------

    def test_step(self, batch, batch_idx):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss)
        # --------------------------

    def configure_optimizers(self):
        learning_rate = 1e-3
        optimizer = torch.optim.SGD(self.parameters(), lr=learning_rate)
        return optimizer


if __name__ == '__main__':
    ae = LitAutoEncoder()

    # Initialize a trainer
    trainer = pl.Trainer(gpus=None, max_epochs=5,
                         progress_bar_refresh_rate=1000, log_every_n_steps=1000,
                         # profiler='simple'
                         )

    # Train the model ⚡
    # trainer.fit(ae, train, val)
    trainer.fit(ae, train)
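
Note (not part of the original issue): the plain-PyTorch script below prints its own iterations-per-second, while the Lightning script does not. For an apples-to-apples number you can time trainer.fit() directly, for example by replacing the plain trainer.fit(ae, train) call above with the sketch below.

    import time

    start = time.time()
    trainer.fit(ae, train)
    elapsed = time.time() - start
    # 5 epochs times the number of batches per epoch, divided by wall-clock time
    print('approx it/s: {0}'.format(5 * len(train) / elapsed))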

Code using just PyTorch:

import os
import math
import time

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pl_bolts.datasets import DummyDataset
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LambdaLR

train = DummyDataset((1, 28, 28), (1,), num_samples=100000)
train = DataLoader(train, batch_size=32)

val = DummyDataset((1, 28, 28), (1,))
val = DataLoader(val, batch_size=32)

test = DummyDataset((1, 28, 28), (1,))
test = DataLoader(test, batch_size=32)


class PlainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, batch):
        # --------------------------
        # REPLACE WITH YOUR OWN
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss


if __name__ == '__main__':
    model = PlainModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    for epoch_num in range(30):
        batch_count = 0
        start_time = time.time()
        for batch in train:
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()
            batch_count += 1
        end_time = time.time()
        print('Epoch {0} speed: {1} in {2} time'.format(epoch_num, batch_count / (end_time - start_time),
                                                        end_time - start_time))

The Lightning code runs at about 450 it/s on my Mac (CPU only), versus about 650 it/s for vanilla PyTorch, so the vanilla PyTorch code is roughly 1.44 times faster than Lightning.
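
Not part of the original report, but one way to see where the extra time goes is Lightning's built-in simple profiler, which the commented-out profiler='simple' argument in the script above already hints at. Enabling it makes fit() print a per-action timing summary (training step, optimizer step, batch loading, and so on) when training finishes, which makes the per-step framework overhead visible:

    trainer = pl.Trainer(gpus=None, max_epochs=5,
                         progress_bar_refresh_rate=1000, log_every_n_steps=1000,
                         profiler='simple')  # prints a timing report at the end of fit()
    trainer.fit(ae, train)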

To Reproduce

Use the above code.

Expected behavior

Lightning runs at almost the same speed as the vanilla PyTorch code.

Environment

  • PyTorch Version (e.g., 1.0): 1.5.0
  • OS (e.g., Linux): Mac
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): n/a
  • Python version: 3.8.3
  • CUDA/cuDNN version: n/a
  • GPU models and configuration: n/a
  • Any other relevant information: n/a
  • PL version: tried both 1.1.8 and 1.2.1

Additional context

n/a

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

1 reaction
t-vi commented, Mar 2, 2021

Lightning includes “quite a bit of magic” that adds a fixed overhead over PyTorch. As @SeanNaren points out, this overhead is fixed and the scaling behaviour should be very similar, so for non-trivial networks it should not matter as much. Incidentally, PyTorch has its own performance issue with nn.Module, see https://github.com/pytorch/pytorch/pull/50431. The symptoms are somewhat similar (but the scaling behaviour is arguably worse, as larger networks have more modules) and we might learn from their analysis.
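To make the “fixed overhead” point concrete, a rough check (not from this thread) is to widen the hidden layers in both scripts and measure throughput again: as the per-batch compute grows, the constant per-step bookkeeping matters less. Below is a self-contained sketch of the plain-PyTorch side; plain_it_per_s and the hidden knob are made up for illustration, and num_samples is reduced so the wide runs stay quick.

import time

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from pl_bolts.datasets import DummyDataset


def plain_it_per_s(hidden):
    # Same dummy data as in the issue, just fewer samples per measurement.
    train = DataLoader(DummyDataset((1, 28, 28), (1,), num_samples=20000), batch_size=32)
    model = nn.Sequential(nn.Linear(28 * 28, hidden), nn.ReLU(), nn.Linear(hidden, 28 * 28))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    batches, start = 0, time.time()
    for x, _ in train:
        x = x.view(x.size(0), -1)
        optimizer.zero_grad()
        loss = F.mse_loss(model(x), x)
        loss.backward()
        optimizer.step()
        batches += 1
    return batches / (time.time() - start)


for hidden in (128, 1024, 4096):
    print('hidden={0}: {1:.1f} it/s'.format(hidden, plain_it_per_s(hidden)))

Running the Lightning script with the same widened encoder/decoder should show the absolute it/s gap staying roughly constant while the relative gap shrinks, which is the scaling behaviour described above.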

0 reactions
stale[bot] commented, Apr 3, 2021

This issue has been automatically marked as stale because it hasn’t had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
