[Community] Testing Stable Diffusion is hard 🄔
It's really difficult to test Stable Diffusion due to the following:
- Continuous output: Diffusion models take float values as input and produce float values as output. This is different from NLP models, which tend to take int64 as inputs and produce int64 as outputs.
- Huge output dimensions: If an image has an output size of `(1, 512, 512, 3)`, there are 512 * 512 * 3 ≈ 800,000 values that need to be within a given range. Say you want to test for a maximum difference of `(pred - ref).abs() < 1e-3`: we have roughly a million values where this has to hold true. This is quite different from NLP, where we rather test things like text generation or final logit layers, which are usually no bigger than a dozen or so tensors of size 768 or 1024.
- Error propagation: We cannot simply test one forward pass of Stable Diffusion, because in practice people use 50 forward passes. Error propagation becomes a real problem in this case (a toy sketch follows below). This again is different from, say, generation in NLP, because there errors can be somewhat "smoothed" out at every generation step, since an "argmax" of the "softmax" output is applied after each step.
- Composite systems: Stable Diffusion has three main components for inference: a UNet, a scheduler, and a VAE decoder. The UNet and the scheduler are very entangled during the forward pass. Just because we know the forward passes of the scheduler and the UNet work independently doesn't mean that using them together works.
=> Therefore, we need to do full integration tests, meaning we need to make sure that the output of a full denoising process stays within a given error range. At the moment, though, we're having quite a few problems getting fully reproducible results on different GPUs, CUDA versions, etc. (especially for FP16).
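To make the error-propagation point concrete, here is a toy sketch in plain PyTorch (not diffusers code): run the same 50-step iteration once in float32 and once in float64 from an identical starting point, and watch the per-step rounding error accumulate.

```python
import torch

torch.manual_seed(0)
x64 = torch.randn(1, 4, 64, 64, dtype=torch.float64)
x32 = x64.float()  # identical starting point, lower precision

for t in range(50):
    # Stand-in for one denoising step (UNet forward + scheduler update).
    x64 = x64 + 0.1 * torch.sin(x64 + t)
    x32 = x32 + 0.1 * torch.sin(x32 + t)

# The accumulated difference is typically much larger than the ~1e-7 rounding
# error of a single float32 step, so a per-step tolerance tells you little
# about the error after a full 50-step denoising loop.
print((x32.double() - x64).abs().max())
```

The same mechanism is what makes FP16 so hard to pin down: each step's deviation feeds into the next step's input.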
That being said, it is extremely important to test Stable Diffusion to avoid issues like https://github.com/huggingface/diffusers/issues/902 in the future, while still being able to improve speed with PRs like https://github.com/huggingface/diffusers/pull/371.
At the moment, we're running multiple integration tests for all 50 diffusion steps every time a PR is merged to master, see:
- https://github.com/huggingface/diffusers/blob/31af4d17e81308887ff63080d49fba644e6c3963/tests/pipelines/stable_diffusion/test_stable_diffusion.py#L518
- https://github.com/huggingface/diffusers/blob/31af4d17e81308887ff63080d49fba644e6c3963/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py#L465
- https://github.com/huggingface/diffusers/blob/31af4d17e81308887ff63080d49fba644e6c3963/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py#L262
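Schematically, those tests follow roughly this pattern (a condensed sketch with placeholder expected values; see the linked files for the real code): run the full pipeline with a fixed seed and compare a small slice of the output image against reference values recorded from a known-good run.

```python
import numpy as np
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(
    "A painting of a squirrel eating a burger",
    generator=generator,
    num_inference_steps=50,
    output_type="np",
).images[0]  # (512, 512, 3) float array in [0, 1]

assert image.shape == (512, 512, 3)

# Compare a small corner slice against values recorded from a known-good run.
# The numbers below are placeholders, not real reference values.
image_slice = image[-3:, -3:, -1].flatten()
expected_slice = np.array([0.57, 0.56, 0.47, 0.54, 0.56, 0.46, 0.51, 0.53, 0.45])
assert np.abs(image_slice - expected_slice).max() < 1e-2
```

Checking only a 3x3 slice keeps the reference values manageable, but it also means that a regression affecting other parts of the image can slip through, which is exactly the coverage trade-off discussed above.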
Nevertheless, the tests weren't sufficient to detect https://github.com/huggingface/diffusers/issues/902.
Testing puzzle 🧩: How can we find the best trade-off between fast & inexpensive tests and the best possible test coverage, taking the above points into account?
We've already looked quite a bit into https://pytorch.org/docs/stable/notes/randomness.html.
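The settings from those notes that matter most here, collected in one place (a sketch; full cross-machine reproducibility on GPU also depends on the CUDA/cuDNN versions in use):

```python
import os

# Must be set before CUDA initialization for deterministic cuBLAS (CUDA >= 10.2).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.manual_seed(0)                      # seeds the CPU and all GPU RNGs
torch.use_deterministic_algorithms(True)  # error out on nondeterministic ops
torch.backends.cudnn.benchmark = False    # disable cuDNN conv autotuning
```

Note that even with all of this, determinism is only guaranteed on the same hardware and software stack, which is why matching reference outputs across different GPUs remains hard.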
Top GitHub Comments
That's a very good question - we also ran into this problem with @anton-l
In short, for public PRs we don't test any models that require auth_token verification. The reason is that public PRs cannot have access to the secret token of our GitHub repo, which means those PRs would fail (please correct me if I'm wrong @anton-l).
When we merge to "main", we always have access to our secret GitHub token and can then run the tests on those models.
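A sketch of how such a guard can look in pytest (the environment variable name here is just an example):

```python
import os
import pytest

# Skip tests that need a gated model unless an auth token is available,
# e.g. on public PRs where repository secrets are not exposed to CI.
requires_auth = pytest.mark.skipif(
    os.getenv("HUGGING_FACE_HUB_TOKEN") is None,
    reason="requires a Hugging Face auth token (not available on public PRs)",
)

@requires_auth
def test_full_denoising_loop():
    ...
```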
My 2c on this: ideally you have some fast unit tests that cover baseline correctness issues, but for the vast majority of the issues you're highlighting you need time-consuming runs, so there are two things that help dramatically in my experience.
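On the fast-unit-test side, one common pattern (a sketch, assuming diffusers' UNet2DModel API, with a deliberately tiny randomly-initialized model) is to check shapes, dtypes, and basic numerical sanity in milliseconds:

```python
import torch
from diffusers import UNet2DModel

# A tiny, randomly initialized UNet: this checks wiring and numerical
# sanity, not image quality.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)

sample = torch.randn(1, 3, 32, 32)
out = model(sample, timestep=10).sample

assert out.shape == sample.shape
assert torch.isfinite(out).all()  # catch NaNs/Infs from broken layers
```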
Finally, a lot of these problems become easier to handle if the models are faster, either by changing the models or by using training compilers, so a focus on speed and test time will be crucial to keep this process sane.