Running on CIFAR 10


Hi,

I am trying to train and sample using the CIFAR-10 dataset. Below is my code.

from keras.datasets import mnist
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import numpy as np
import tensorflow as tf

model = Unet(
    dim = 16,
    dim_mults = (1, 2, 4)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = np.asarray(x_train)
x_train = x_train.astype(np.float16)
new_train = torch.from_numpy(np.swapaxes(x_train,1,3))
training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1

trainer = Trainer(
    diffusion,
    new_train,
    train_batch_size = 128,
    train_lr = 1e-4,
    train_num_steps = 70000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.9999,                # exponential moving average decay
    amp = True                        # turn on mixed precision
)
trainer.train()

I modified Trainer so that it accepts the dataset directly. The original Trainer had the following code:

self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip)
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())

which I changed to:

my_dataset = TensorDataset(data) # create your dataset
dl = DataLoader(data, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
self.dl = cycle(dl)

With the setup above, the training loss goes to inf after 20k iterations. If I stop before that and sample from the model, the images are just a bunch of random colors. Is there a script I can use to generate samples from CIFAR-10?

Thank You
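
For readers reproducing this setup, here is a minimal sketch of one way to prepare the images before handing them to the Trainer. It assumes, as the `# images are normalized from 0 to 1` comment in the snippet suggests, that the diffusion model expects channel-first float inputs in the range 0 to 1. The helper name `cifar_to_nchw_unit` and the synthetic `fake_batch` array are hypothetical, used only for illustration; the real `x_train` from `cifar10.load_data()` has the same (N, 32, 32, 3) uint8 layout. Note that `np.swapaxes(x, 1, 3)` yields (N, C, W, H), which also swaps height and width, whereas `transpose(0, 3, 1, 2)` keeps height before width.

```python
import numpy as np

def cifar_to_nchw_unit(images_nhwc: np.ndarray) -> np.ndarray:
    """Convert uint8 images shaped (N, H, W, 3) to float32 (N, 3, H, W) in [0, 1]."""
    if images_nhwc.ndim != 4 or images_nhwc.shape[-1] != 3:
        raise ValueError("expected an array shaped (N, H, W, 3)")
    # Scale in float32 rather than float16 to keep precision and range.
    scaled = images_nhwc.astype(np.float32) / 255.0
    # Move channels to axis 1 while keeping height before width.
    return np.ascontiguousarray(scaled.transpose(0, 3, 1, 2))

# Synthetic stand-in for x_train (hypothetical data, same layout as CIFAR-10).
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

prepared = cifar_to_nchw_unit(fake_batch)
print(prepared.shape, prepared.dtype)
```

The resulting array can be wrapped with `torch.from_numpy(prepared)` if a tensor is needed. Whether this alone resolves the divergence described above is not established in the thread.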

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 11

Top GitHub Comments

2 reactions
greens007 commented, Aug 19, 2022

I have found what caused it. You should set amp to False when training on CIFAR-10. Once I did this, the model converged and generated normal pictures instead of a bunch of random colours.
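
The thread does not explain why disabling amp helps. One possible contributor, offered here only as an assumption, is the limited range of half precision: float16 cannot represent values above 65504, so intermediate results computed from unscaled 0 to 255 pixel values can overflow to inf, while the same computation on values scaled to 0 to 1 stays finite. The small sketch below only illustrates that numeric limit with NumPy; the multiplier 300 is arbitrary, and nothing here reproduces the actual training run.

```python
import numpy as np

# Largest finite value representable in half precision.
fp16_max = float(np.finfo(np.float16).max)

raw_pixel = np.float16(255.0)            # unscaled pixel value
scaled_pixel = np.float16(255.0 / 255.0) # same pixel scaled to [0, 1]
factor = np.float16(300.0)               # arbitrary illustrative multiplier

with np.errstate(over="ignore"):
    overflowed = raw_pixel * factor      # 76500 exceeds the float16 range
stays_finite = scaled_pixel * factor     # 300 is well within range

print(fp16_max, overflowed, stays_finite)
```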

0 reactions
DevJake commented, Oct 12, 2022

I wonder why that is… I would’ve figured the subtle compression applied by the JPG format would cause potential data loss and thus a loss in performance. Equally, it might act as a form of very low-level image augmentation by adding artefacts. Got anything more on your findings? I’d be interested to know how it was determined.
