
amp + checkpoint loading = problems

See original GitHub issue

Hi, as you know I have been experimenting with amp for a while now. Today I stumbled upon some very unexpected behavior. My FP16 models (trained with amp) do just as well as the FP32 models on their own. But I usually also ensemble my models by doing something like this:

results = []
for c in checkpoints:
    network.load_state_dict(torch.load(c))  # load each checkpoint into the same network instance
    results.append(network(data))

Interestingly, the performance drops quite a bit if I do that with amp enabled. To illustrate this, I created a minimal example with MNIST:

from copy import deepcopy
import torch
import matplotlib
matplotlib.use("agg")
from torch.backends import cudnn
from apex import amp
import argparse
from torch import cuda
from torch import nn
from urllib import request
import gzip
import pickle
import os
import numpy as np


def load(mnist_file):
    init()
    with open(mnist_file, 'rb') as f:
        mnist = pickle.load(f)
    data_tr = mnist["training_images"].reshape(60000, 1, 28, 28)
    data_te = mnist["test_images"].reshape(10000, 1, 28, 28)
    return data_tr, mnist["training_labels"], data_te, mnist["test_labels"]


filename = [
    ["training_images", "train-images-idx3-ubyte.gz"],
    ["test_images", "t10k-images-idx3-ubyte.gz"],
    ["training_labels", "train-labels-idx1-ubyte.gz"],
    ["test_labels", "t10k-labels-idx1-ubyte.gz"]
]


def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading "+name[1]+"...")
        request.urlretrieve(base_url+name[1], name[1])
    print("Download complete.")


def save_mnist():
    mnist = {}
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28)
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist,f)
    print("Save complete.")


def init():
    if not os.path.isfile("mnist.pkl"):
        download_mnist()
        save_mnist()


def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent


class GlobalAveragePool(nn.Module):
    def forward(self, x):
        axes = range(2, len(x.shape))
        for a in axes[::-1]:
            x = x.mean(a, keepdim=False)
        return x


def get_default_network_config():
    """
    returns a dictionary that contains pointers to conv, nonlin and norm ops and the default kwargs I like to use
    :return:
    """
    props = {}
    props['conv_op'] = nn.Conv2d
    props['conv_op_kwargs'] = {'stride': 1, 'dilation': 1, 'bias': True} # kernel size will be set by network!
    props['nonlin'] = nn.LeakyReLU
    props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    props['norm_op'] = nn.BatchNorm2d
    props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
    props['dropout_op'] = nn.Dropout2d
    props['dropout_op_kwargs'] = {'p': 0.0, 'inplace': True}
    return props


class ConvDropoutNormReLU(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(ConvDropoutNormReLU, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.

        self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
                                             padding=[(i - 1) // 2 for i in kernel_size],
                                             **network_props['conv_op_kwargs'])

        # maybe dropout
        if network_props['dropout_op'] is not None:
            self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
        else:
            self.do = lambda x: x

        if network_props['norm_op'] is not None:
            self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
        else:
            self.norm = lambda x: x

        self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])

        self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)

    def forward(self, x):
        return self.all(x)


class StackedConvLayers(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(StackedConvLayers, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.
        network_props_first = deepcopy(network_props)

        if first_stride is not None:
            network_props_first['conv_op_kwargs']['stride'] = first_stride

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
            *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props) for _ in range(num_convs - 1)]
        )

    def forward(self, x):
        return self.convs(x)


class SimpleNetwork(nn.Module):
    def __init__(self, props=None):
        super(SimpleNetwork, self).__init__()
        if props is None:
            props = get_default_network_config()
        self.stage1 = StackedConvLayers(1, 16, (3, 3), props, 2, 1)
        self.stage2 = StackedConvLayers(16, 32, (3, 3), props, 2, 2)
        self.stage3 = StackedConvLayers(32, 64, (3, 3), props, 3, 2)
        self.stage4 = StackedConvLayers(64, 128, (3, 3), props, 3, 2)
        self.pool = GlobalAveragePool()
        self.fc = nn.Linear(128, 10, False)

    def forward(self, x):
        return self.fc(self.pool(self.stage4(self.stage3(self.stage2(self.stage1(x))))))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=False, default=None)
    parser.add_argument("--test_only", action="store_true", default=False)
    parser.add_argument("-s", help="output filename for trained model")
    parser.add_argument("-test_fnames", required=False, nargs='+')

    args = parser.parse_args()
    seed = args.seed
    test_only = args.test_only

    # seeding
    np.random.seed(seed)
    cuda.manual_seed(np.random.randint(10000))
    cuda.manual_seed_all(np.random.randint(10000))
    cudnn.deterministic = True
    cudnn.benchmark = False

    amp_handle = amp.init()

    data_tr, target_tr, data_te, target_te = load("mnist.pkl")

    data_tr = torch.from_numpy(data_tr).float().cuda()
    target_tr = torch.from_numpy(target_tr).long().cuda()
    data_te = torch.from_numpy(data_te).float().cuda()
    target_te = torch.from_numpy(target_te).long().cuda()

    network = SimpleNetwork().cuda()

    batch_size = 512

    if not test_only:
        optimizer = torch.optim.Adam(network.parameters(), 1e-3, amsgrad=True, weight_decay=1e-5)

        epochs = 30

        loss = torch.nn.CrossEntropyLoss()

        network.train()
        for epoch in range(epochs):
            print(epoch)
            optimizer.param_groups[0]['lr'] = poly_lr(epoch, epochs, 1e-3, 0.9)

            for _ in range(60000 // batch_size):
                optimizer.zero_grad()
                idxs = np.random.choice(60000, batch_size)
                data = data_tr[idxs]
                target = target_tr[idxs]

                out = network(data)

                l = loss(out, target)

                with amp_handle.scale_loss(l, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

        torch.save(network.state_dict(), args.s)

        with torch.no_grad():
            network.eval()
            out = network(data_te)

            _, amax = out.max(dim=1)
            acc = (amax == target_te).float().mean()
            print("accuracy on test: ", acc)
    else:
        if not isinstance(args.test_fnames, list):
            args.test_fnames = [args.test_fnames]

        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))

            with torch.no_grad():
                network.eval()
                out = network(data_te)

                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test: ", acc)

I just hacked this together, so please ignore any potential ugliness in the code.

Here is how you can reproduce the problem: First, train the network several times and save to different output files:

python train_mnist.py --seed 1 -s mnist_seed1.model

accuracy on test: tensor(0.9959, device='cuda:0')

python train_mnist.py --seed 2 -s mnist_seed2.model

accuracy on test: tensor(0.9955, device='cuda:0')

python train_mnist.py --seed 3 -s mnist_seed3.model

accuracy on test: tensor(0.9949, device='cuda:0')

Now that you have the trained models, you can run the testing by passing the filenames to the script like this:

python train_mnist.py --test_only -test_fnames mnist_seed1.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')

python train_mnist.py --test_only -test_fnames mnist_seed2.model

file mnist_seed2.model accuracy on test: tensor(0.9955, device='cuda:0')

python train_mnist.py --test_only -test_fnames mnist_seed3.model

file mnist_seed3.model accuracy on test: tensor(0.9949, device='cuda:0')

The script also supports being given several model checkpoints at once, and it will test all of them one after the other. Although I am not ensembling here, this is the same procedure that I use in my ensembling code, and the same issue appears here as well:

python train_mnist.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.1135, device='cuda:0')
file mnist_seed3.model accuracy on test: tensor(0.1029, device='cuda:0')

If you look into the script (line 240+), it is not doing anything differently than before, except loading a new checkpoint with network.load_state_dict between test set predictions. Yet we are seeing a big drop in performance from the second checkpoint onwards.

To demonstrate that this is not a problem with the files themselves, I ran it in a different order with the same result:

python train_mnist.py --test_only -test_fnames mnist_seed3.model mnist_seed1.model mnist_seed2.model

file mnist_seed3.model accuracy on test: tensor(0.9949, device='cuda:0')
file mnist_seed1.model accuracy on test: tensor(0.1036, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.1010, device='cuda:0')

I can fix this issue in this particular script by not initializing amp when I am only running the testing, i.e. by replacing amp_handle = amp.init() with:

    if not test_only:
        amp_handle = amp.init()

After that change, testing multiple checkpoints runs nicely:

python train_mnist.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device='cuda:0')
file mnist_seed2.model accuracy on test: tensor(0.9954, device='cuda:0')
file mnist_seed3.model accuracy on test: tensor(0.9949, device='cuda:0')

I am not sure what is going on here, but I think it would be rather important to understand. It took me a good 3 hours to finally figure out what was causing my severe performance regression today. Do you have any idea how this issue could be solved? I need to be able to load checkpoints during and after my trainings and rely on this working 😃

Best, Fabian
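For clarity, here is a minimal sketch of the workaround described above, with amp.init() guarded behind the training path and the checkpoint evaluation pulled into a small helper. It assumes the variables defined in the script above (network, args, test_only, data_te, target_te); the evaluate_checkpoints function is purely illustrative and not part of the original script:

def evaluate_checkpoints(network, fnames, data_te, target_te):
    # Load each checkpoint into the same network instance and report test accuracy.
    network.eval()
    with torch.no_grad():
        for f in fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))
            out = network(data_te)
            _, amax = out.max(dim=1)
            acc = (amax == target_te).float().mean()
            print("file", f, "accuracy on test: ", acc)

if not test_only:
    # amp is only initialized when actually training, mirroring the fix from the issue.
    amp_handle = amp.init()
    # ... training loop from the script above ...
else:
    evaluate_checkpoints(network, args.test_fnames, data_te, target_te)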

Issue Analytics

  • State: open
  • Created: 5 years ago
  • Reactions: 2
  • Comments: 41 (7 by maintainers)

Top GitHub Comments

5 reactions
hadaev8 commented on Sep 30, 2019

So, I load the model and then the amp state dict, and I still have problems with loss spikes. Any ideas?

3 reactions
npmhung commented on Apr 16, 2019

Is there any update yet? I'm running into the same problem and cannot figure out what is causing it. I tried to restore the model and continue training it, but it seemed that I was training from scratch, not from the checkpoint.
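For reference, the newer apex API (amp.initialize rather than the deprecated amp.init() handle used in the issue) documents a checkpointing pattern that saves and restores amp's loss-scale state together with the model and optimizer. A minimal sketch, assuming model and optimizer are already constructed, with the opt_level and file name as placeholders, and with no claim here that it resolves the loss spikes reported above:

import torch
from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Save model, optimizer and amp (loss-scale) state together.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "amp": amp.state_dict(),
}
torch.save(checkpoint, "amp_checkpoint.pt")

# Restore: call amp.initialize first, then load all three state dicts.
checkpoint = torch.load("amp_checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
amp.load_state_dict(checkpoint["amp"])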

Read more comments on GitHub >

Top Results From Across the Web

Automatic Mixed Precision — PyTorch Tutorials 1.12.1+cu102
If a checkpoint was created from a run without Amp, and you want to resume training with Amp, load model and optimizer states... (see the GradScaler sketch after this list)

PyTorch loading GradScaler from checkpoint - Stack Overflow
I am saving my model, optimizer, scheduler, and scaler in a general checkpoint. Now when I load ...

Trainer — PyTorch Lightning 1.8.5.post0 documentation
If None and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous trainer.fit call will...

Trainer - Hugging Face
If this pytorch issue gets resolved it will be possible to change this class ... If a bool and equals True , load...

Clara Train FAQ - NVIDIA Documentation Center
What can I do if AMP doesn't show me any difference in the model memory footprint? How can I save and load...
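Building on the PyTorch AMP tutorial and Stack Overflow results above, here is a minimal sketch of checkpointing with native torch.cuda.amp (PyTorch 1.6+), where the GradScaler state is saved and restored alongside the model and optimizer; model, optimizer and the file name are placeholders assumed to exist in the surrounding training code:

import torch

scaler = torch.cuda.amp.GradScaler()

# Save scaler state alongside model and optimizer so the loss scale survives a restart.
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}, "native_amp_checkpoint.pt")

# Restore all three before resuming the training loop.
ckpt = torch.load("native_amp_checkpoint.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scaler.load_state_dict(ckpt["scaler"])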
