Running on CIFAR 10
Hi,
I am trying to train on the CIFAR-10 dataset and then sample from the trained model. Below is the code I am using.
from keras.datasets import mnist
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import numpy as np
import tensorflow as tf

model = Unet(
    dim = 16,
    dim_mults = (1, 2, 4)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = np.asarray(x_train)
x_train = x_train.astype(np.float16)
new_train = torch.from_numpy(np.swapaxes(x_train,1,3))
training_images = torch.randn(8, 3, 128, 128) # images are normalized from 0 to 1

trainer = Trainer(
    diffusion,
    new_train,
    train_batch_size = 128,
    train_lr = 1e-4,
    train_num_steps = 70000,         # total training steps
    gradient_accumulate_every = 2,   # gradient accumulation steps
    ema_decay = 0.9999,              # exponential moving average decay
    amp = True                       # turn on mixed precision
)

trainer.train()
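For anyone reproducing this, one thing worth checking is the input tensor itself. `tf.keras.datasets.cifar10.load_data()` returns `uint8` arrays shaped `(N, 32, 32, 3)` with values in 0..255, while the comment in the snippet says images are normalized from 0 to 1. Below is a minimal sketch of a conversion, assuming the model wants channels-first `float32` input in `[0, 1]`. The helper name `cifar_to_model_input` is hypothetical, and random data stands in for `x_train`. This is only a guess at a contributing factor, not a confirmed diagnosis.

```python
import numpy as np

def cifar_to_model_input(images_uint8: np.ndarray) -> np.ndarray:
    """Convert (N, H, W, 3) uint8 images to (N, 3, H, W) float32 in [0, 1].

    Hypothetical helper for illustration; not part of denoising_diffusion_pytorch.
    """
    if images_uint8.ndim != 4 or images_uint8.shape[-1] != 3:
        raise ValueError("expected an array shaped (N, H, W, 3)")
    # Scale 0..255 down to 0..1 in float32 rather than float16.
    scaled = images_uint8.astype(np.float32) / 255.0
    # transpose(0, 3, 1, 2) moves channels forward and keeps height before width;
    # swapaxes(1, 3) would also swap height and width.
    return np.ascontiguousarray(scaled.transpose(0, 3, 1, 2))

# Random stand-in for x_train so the sketch runs without downloading CIFAR-10.
fake_batch = np.random.randint(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
model_input = cifar_to_model_input(fake_batch)
print(model_input.shape, model_input.dtype)
```

The result can be passed to `torch.from_numpy` in the same way as `new_train` above.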
I modified Trainer so that it accepts the dataset tensor directly instead of a folder. The original Trainer had the code below
self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip)
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
which I modified to
my_dataset = TensorDataset(data) # create your dataset
dl = DataLoader(data, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
self.dl = cycle(dl)
With this setup, the training loss goes to inf after about 20k iterations. If I stop before that and sample from the model, the images are just a bunch of random colors. Is there a script I can use to generate samples from CIFAR-10?
Thank You
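A side note on the modified loader lines in the question: `my_dataset` is created but the `DataLoader` is built from the raw tensor, and the two options produce differently shaped batches. The sketch below shows the difference on placeholder data. The `cycle` function here is a local stand-in written for this example, assumed to behave like the helper of the same name used in the question (restart the loader forever); the tensor sizes are arbitrary.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def cycle(loader):
    """Local stand-in: yield batches forever by restarting the loader."""
    while True:
        for batch in loader:
            yield batch

# Placeholder images in [0, 1], shaped (N, C, H, W).
images = torch.rand(10, 3, 32, 32)

# Option 1: DataLoader over the raw tensor. Each batch is a single tensor.
raw_loader = DataLoader(images, batch_size=4, shuffle=True)
raw_batch = next(cycle(raw_loader))

# Option 2: DataLoader over a TensorDataset. Each batch is a one-element
# sequence, so the image tensor has to be unpacked with [0].
wrapped_loader = DataLoader(TensorDataset(images), batch_size=4, shuffle=True)
wrapped_batch = next(cycle(wrapped_loader))

print(type(raw_batch).__name__, tuple(raw_batch.shape))
print(type(wrapped_batch).__name__, tuple(wrapped_batch[0].shape))
```

Either option can work, as long as the training loop unpacks the batch in the matching way.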
Issue Analytics
- State:
- Created a year ago
- Comments:11
Top GitHub Comments
I have found what caused it. You should set amp to False when training on CIFAR-10. When I did this, the model converged and generated normal pictures instead of a bunch of random colours.
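The thread does not establish why turning off amp helps. As background only, and as an assumption rather than a diagnosis of this issue, half precision has a much narrower range than float32, so unnormalized 0..255 values overflow quickly once they are squared and summed. A small NumPy illustration:

```python
import numpy as np

# Largest finite value representable in float16.
fp16_max = float(np.finfo(np.float16).max)

# Four "pixels" at full intensity, first on the raw 0..255 scale,
# then rescaled to 0..1.
raw_pixels = np.full(4, 255.0, dtype=np.float16)
unit_pixels = raw_pixels / np.float16(255.0)

# Overflow warnings are silenced because overflow is the point of the demo.
with np.errstate(over="ignore"):
    raw_sum_sq = np.sum(raw_pixels * raw_pixels)     # exceeds the float16 range
    unit_sum_sq = np.sum(unit_pixels * unit_pixels)  # stays small

print("float16 max:", fp16_max)
print("sum of squares, raw scale:", raw_sum_sq)
print("sum of squares, unit scale:", unit_sum_sq)
```

Whether this is what happened here would need checking against the actual run, for example by logging the loss and the input value range with amp on and off.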
I wonder why that is… I would’ve figured the subtle compression applied by the JPG format would cause data loss and thus worse performance. Equally, it might act as a form of very low-level image augmentation by adding artefacts. Do you have anything more on your findings? I’d be interested to know how it was determined.