Dreambooth broken, possibly because of ADAM optimizer, possibly more.
I think Huggingface's Dreambooth is the only popular SD implementation that also uses Prior Preservation Loss (PPL), so I've been motivated to get it working, but the results have been terrible and the entire model degrades, regardless of the number of timesteps, learning rate, PPL on/off, number of instance samples, number of class regularization samples, etc. I've read the paper and found that they actually unfreeze everything, including the text embedder (and the VAE? I'm not sure, so I leave it frozen). So I implemented textual inversion within the dreambooth example (a new token, unfreezing a single row of the embedder), which improves results considerably, but the whole model still degrades no matter what.
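For concreteness, the textual-inversion-style change I bolted onto the dreambooth example looks roughly like this. It's only a sketch, and the names (`placeholder_token`, `mask_embedding_grads`) are mine, not the script's:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")

# Register a new placeholder token and grow the embedding matrix by one row.
placeholder_token = "<sks>"
tokenizer.add_tokens(placeholder_token)
text_encoder.resize_token_embeddings(len(tokenizer))
placeholder_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# Freeze the text encoder, then re-enable gradients on the token embedding matrix only.
text_encoder.requires_grad_(False)
embeddings = text_encoder.get_input_embeddings()
embeddings.weight.requires_grad_(True)

def mask_embedding_grads():
    """Call after loss.backward() and before optimizer.step():
    zeroes every embedding-row gradient except the new token's,
    so effectively only that single row gets trained."""
    grad = embeddings.weight.grad
    if grad is not None:
        mask = torch.zeros_like(grad)
        mask[placeholder_id] = 1.0
        grad.mul_(mask)
```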
Someone smarter than me can confirm, but I think the culprit is ADAM:
My hypothesis is that since Adam, as used here with weight decay (AdamW), drags all weights of the unet etc. toward 0 on every step, it ruins parts of the model that aren't concurrently being trained during the finetuning.
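A quick toy check of that intuition (not the training script itself, and with the learning rate and decay exaggerated so the effect shows up quickly): under AdamW, a parameter whose gradient is exactly zero still shrinks on every step, because the decoupled weight decay p ← p − lr·weight_decay·p is applied regardless of the gradient.

```python
import torch

# Hypothetical toy setup: one parameter that never receives a useful gradient,
# stepped with AdamW. lr / weight_decay are exaggerated vs. the script's defaults
# purely so the shrinkage is visible within a few thousand steps.
p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.AdamW([p], lr=1e-3, weight_decay=0.1)

for _ in range(10_000):
    opt.zero_grad()
    p.grad = torch.zeros_like(p)  # pretend this weight is never exercised by the finetuning data
    opt.step()

# Each step multiplies p by (1 - lr * weight_decay) = 0.9999,
# so after 10k steps p has decayed to roughly exp(-1) ~ 0.37 of its original value.
print(p.data)
```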
I've tested with `weight_decay` set to 0, and the results seem considerably better, but I think the entire model is still degrading. I'm trying SGD next, so fingers crossed, but there may still be some dragon lurking in the depths even after removing Adam.
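The swaps I'm testing look roughly like this (a sketch of the optimizer construction only; the exact variable names in train_dreambooth.py may differ, and the learning rates are just the ones I've been playing with):

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
params_to_optimize = list(unet.parameters())

# Variant 1: keep AdamW but disable the decay entirely.
optimizer = torch.optim.AdamW(params_to_optimize, lr=5e-6, weight_decay=0.0)

# Variant 2: plain SGD with momentum -- no adaptive scaling, no decay,
# but it needs a much larger learning rate to learn the instance at all.
optimizer = torch.optim.SGD(params_to_optimize, lr=1e-3, momentum=0.9)
```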
A point of reference on this journey is the JoePenna "Dreambooth" repo, which doesn't implement PPL and yet preserves priors much, much better than this example; it also learns the instance better, is far more editable, and preserves out-of-class concepts rather well. I expect more from this huggingface dreambooth example, and I'm trying to find out why it's not delivering.
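For anyone following along, this is roughly what the prior preservation loss amounts to, as I understand it from the paper and the example script (a sketch only; `model_pred` and `target` come from the usual noising + unet forward pass, and `prior_loss_weight` is the weighting knob):

```python
import torch
import torch.nn.functional as F

def dreambooth_loss(model_pred: torch.Tensor, target: torch.Tensor, prior_loss_weight: float = 1.0):
    # The batch is built by concatenating instance images with class ("prior") images,
    # so the first half of the predictions belongs to the instance and the second half
    # to the class regularization samples.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

    # The class-sample term keeps the model close to its prior for the class,
    # while the instance term teaches it the new subject.
    return instance_loss + prior_loss_weight * prior_loss
```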
Any thoughts or guidance?
EDIT1A: SGD didn't learn the instance at 1000 steps + lr=5e-5, but it definitely preserved the priors way better (by visual inspection; the loss doesn't really decrease much in any of my inversion/dreambooth experiments).
EDIT1B: Another test failed to learn using SGD at 1500 steps + lr=1e-3 + momentum=0.9. It might be trying to learn, but not much. Priors were still nicely preserved, though.
EDIT1C: 1500 steps + lr=5e2 learned great, was editable, and didn't destroy other priors!!!
EDIT2: JoePenna seems to use AdamW too, so I'm not sure what's up anymore, but I'm still getting quite poor results training with this library's (huggingface's) dreambooth example.
Top GitHub Comments
Hi everyone! Sorry to be so late here.
We ran a lot of experiments with the script to see if there are any issues or if it's broken. It turns out we need to carefully pick hyperparameters like the LR and the number of training steps to get good results with dreambooth.
Also, the main reason the results with this script were not as good as the CompVis forks is that the `text_encoder` is being trained in those forks, and that makes a big difference in quality, especially on faces. We compiled all our experiments in this report, and also added an option to train the `text_encoder` in the script, which can be enabled by passing the `--train_text_encoder` argument. Note that if we train the `text_encoder`, the training won't fit on a 16GB GPU anymore; it will need at least 24GB of VRAM. It should still be possible to do it on 16GB using DeepSpeed, but it will be slower. Please take a look at the report, hope you find it useful.
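In case it helps to see what the flag changes, here is roughly what `--train_text_encoder` does (an approximation, not the script verbatim): the text encoder is left unfrozen and its parameters are handed to the optimizer together with the UNet's.

```python
import itertools
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder")

train_text_encoder = True  # i.e. --train_text_encoder was passed

if not train_text_encoder:
    text_encoder.requires_grad_(False)

# With the flag on, both models are optimized jointly (which is what pushes VRAM past 16GB).
params_to_optimize = (
    itertools.chain(unet.parameters(), text_encoder.parameters())
    if train_text_encoder
    else unet.parameters()
)
# lr here is just a placeholder; see the report for the values we found to work.
optimizer = torch.optim.AdamW(params_to_optimize, lr=5e-6, weight_decay=1e-2)
```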
I think @patil-suraj will very soon release a nice update of dreambooth 😃