More "OpenAI Blog Post" Training | Depth 32 | Heads 8 | LR 5e-4
Edit: Moved to discussions: https://github.com/lucidrains/DALLE-pytorch/discussions/106
Hey, all. Some of you might know I’m practicing and learning about machine learning with dalle-pytorch and a dataset consisting of the images OpenAI presented in the DALL-E blog post. I honestly don’t have the money to train this whole dataset,
edit: this is no longer true. Using the 1024 VQGAN from the “Taming Transformers” research, it’s now quite possible to train on a full dataset of 1,000,000 image-text pairs, and I’m doing just that. I hope to have it finished in about a week. I assume someone else will release a dalle-pytorch model trained properly on COCO and other image sets before then, but if they don’t, check here for updates.
Anyway, it ran for ~36,000 steps. As you can see, it… still really likes mannequins. I’m considering removing them from the dataset. But you’ll also notice that the network has actually got a decent idea of the general colors that belong with certain types of prompts.
Some Samples from Near the End of Training
Every Text-Image Reconstruction
Deliverables (my train_dalle.py)
https://gist.github.com/afiaka87/850fb3cc48edde8a7ed4cb1ce53b6bd2
This has some code in it that actually manages to deal with truncated images via try/except. Apparently detecting a corrupted PNG is harder than P vs NP. PIL’s `Image.verify()` method doesn’t catch all of them, and Python’s built-in `imghdr` library doesn’t catch them all either. So you just catch `OSError` and return an item further along. Works well enough.
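For anyone curious, the idea boils down to something like this (a minimal sketch, not the exact code from the gist; the class name, globbing, and image-only return are simplifications):

```python
# Minimal sketch of skipping corrupted/truncated images inside a PyTorch Dataset.
from pathlib import Path

from PIL import Image, UnidentifiedImageError
from torch.utils.data import Dataset


class ImageFolderDataset(Dataset):
    def __init__(self, folder, transform=None):
        self.image_paths = sorted(Path(folder).glob('*.png'))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        path = self.image_paths[index]
        try:
            image = Image.open(path).convert('RGB')
        except (OSError, UnidentifiedImageError):
            # Truncated or corrupted file: fall back to the next item instead of crashing.
            return self.__getitem__((index + 1) % len(self))
        if self.transform is not None:
            image = self.transform(image)
        return image
```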
Parameters
SHUFFLE = True
EPOCHS = 28 # This wound up being less than a single epoch, of course.
BATCH_SIZE = 16
LEARNING_RATE = 0.0005 # I found this learning rate to be more suitable than 0.0003 in my hyperparameter sweep post
GRAD_CLIP_NORM = 0.5
DEPTH = 32
HEADS = 8
MODEL_DIM = 512
TEXT_SEQ_LEN = 256
DIM_HEAD = 64
REVERSIBLE = True
ATTN_TYPES = ('full',)
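For reference, here’s roughly how those settings map onto the DALLE constructor in dalle-pytorch (a sketch following the repo’s README; the DiscreteVAE config and `num_text_tokens` below are placeholder assumptions, not the values from my actual run):

```python
# Sketch: wiring the hyperparameters above into DALLE-pytorch.
from dalle_pytorch import DiscreteVAE, DALLE

# Placeholder VAE; substitute whichever (pre)trained VAE you actually use.
vae = DiscreteVAE(
    image_size=256,
    num_layers=3,
    num_tokens=8192,
    codebook_dim=512,
    hidden_dim=64,
)

dalle = DALLE(
    dim=512,                # MODEL_DIM
    vae=vae,
    num_text_tokens=10000,  # depends on your tokenizer/vocab
    text_seq_len=256,       # TEXT_SEQ_LEN
    depth=32,               # DEPTH
    heads=8,                # HEADS
    dim_head=64,            # DIM_HEAD
    reversible=True,        # REVERSIBLE
    attn_types=('full',),   # ATTN_TYPES
)
```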
Dataset Description
https://github.com/lucidrains/DALLE-pytorch/issues/61#issuecomment-796663342
Just for more info on the dataset itself: it’s roughly 1,100,000 256x256 image-text pairs that were generated by OpenAI’s DALL-E. They presented ~30k unique text prompts, for each of which they posted the top 32 of 512 generations on https://openai.com/blog/dall-e/. Many images were corrupt, and not every prompt has a full 32 examples, but the total number of images winds up being about 1.1 million. If you look at many of the examples on that page, you’ll see that DALL-E (in that form at least) can and will make mistakes. Those mistakes are also in this dataset. Anyway, I’m just messing around, having fun training and whatnot. This is definitely not going to produce a good model or anything.
There are also a large number of images in the dataset which are intended to be used with the “mask” feature. I don’t know if that’s possible yet in DALLE-pytorch though. Anyway, that can’t be helping much.
Top GitHub Comments
@robvanvolt Here are some early results from training on that dataset, by the way. I think we should definitely clean it up with the info from OpenAI. https://wandb.ai/afiaka87/OpenImagesV6/reports/dalle-pytorch-OpenImagesV6-With-Localized-Annotations---Vmlldzo1MzgyMTU
After ~15k iters, I stopped training, added the COCO2018 dataset, and resumed from there for another ~6k steps. https://wandb.ai/afiaka87/OpenImagesV6/reports/OpenImagesV6-COCO--Vmlldzo1MzgyNTI
@lucidrains @Jinglei5
Ha, I do that as well. It’s insane to me how many things just straight up break when you’re dealing with lots of files.
It’s all good though, I managed to figure it out:
find . -type f -name "*.jpg" | parallel mogrify -resize 256x {}
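(For reference: `-resize 256x` scales each image to 256 pixels wide while keeping its aspect ratio, `mogrify` overwrites the files in place, and GNU `parallel` fans the work out across CPU cores, one file per job.)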