Poor sampling quality in upscaler Unets
I’ve seen good quality results with the upscaling Unets in your DALLE2 repo, but I’ve been having trouble getting similar ones with the Imagen Unets over the same training period.
After reviewing the code and the Imagen paper I wonder if this is the problem: https://github.com/lucidrains/imagen-pytorch/blob/6a04c908f65e39af55263ce3dcacca943b90fc5a/imagen_pytorch/imagen_pytorch.py#L1584
I think that this should be `(1.0 - lowres_sample_noise_level)`? I see that you do your augmentation by sampling at a specific timestep based on the overall number of timesteps, but the default of `0.2` would sample at time 200 - which is actually closer to `0.8` augmentation if you were using a linear scale.

As a workaround I am trying to pass `0.8` into my sample function, but I’m not sure this is enough to fully address the issue. Maybe it just takes longer to train since the training operates on a full aug level of `1.0` to `0.0` like the paper does, but I’ll keep the unet training for now.
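To make the mismatch concrete, here is a minimal sketch of the mapping being discussed. It assumes 1000 diffusion timesteps and a simple linear interpretation of "augmentation strength"; `aug_timestep` is a hypothetical helper for illustration, not a function from the repo.

```python
# Sketch of the noise-level -> timestep mapping discussed above.
# Assumption: the noise level is multiplied by the total number of
# timesteps (1000 by default) to pick the augmentation timestep.

def aug_timestep(noise_level: float, total_timesteps: int = 1000) -> int:
    """Timestep at which the low-res conditioning image gets noised."""
    return int(noise_level * total_timesteps)

level = 0.2
t = aug_timestep(level)        # -> 200
flipped = 1.0 - level          # -> 0.8, the proposed correction

print(f"noise_level={level} samples at t={t}; "
      f"on a linear scale that is closer to {flipped} augmentation")
```

Under this reading, passing `0.8` at sample time is just the manual equivalent of applying the `(1.0 - lowres_sample_noise_level)` flip.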
Issue Analytics
- Created: a year ago
- Reactions: 1
- Comments: 104 (71 by maintainers)
Top GitHub Comments
My biggest improvement in the upscaler was achieved by changing the default unet2 `dim = 32` to `128` (inspired by “Friends don’t let friends train small diffusion models”). Training time is much slower and GPU memory usage is much higher, but the upscaler output looks a lot better for me now. This training was seeing discoloration earlier, but it has gone away.
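A rough back-of-the-envelope sketch (my own illustration, not code from the repo) of why the `dim = 32` → `128` change is so expensive: the parameter count of a 3×3 conv between two width-`dim` feature maps grows quadratically with `dim`, so a 4× wider base dim costs roughly 16× the parameters per conv.

```python
# Parameter count of a single 3x3 conv layer (weights + bias),
# used to compare a dim=32 U-Net stage against a dim=128 one.

def conv3x3_params(c_in: int, c_out: int, kernel: int = 3) -> int:
    return c_in * c_out * kernel * kernel + c_out

small = conv3x3_params(32, 32)     # base dim 32
large = conv3x3_params(128, 128)   # base dim 128

print(small, large, large / small)  # roughly a 16x increase per conv
```

The same quadratic scaling applies to memory for activations and gradients, which matches the slower training and higher GPU usage reported above.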