`test_ddpm_ddim_equality` fails if manual seed changed to 4
Describe the bug
The test `PipelineTesterMixin::test_ddpm_ddim_equality` in `test_pipelines.py` fails if the generator is set to `torch.manual_seed(4)` instead of `torch.manual_seed(0)`.
So the fact that the test passes for seed 0 is most likely accidental: the difference fluctuates above or below 0.1 depending on the seed. This is probably also why `PipelineTesterMixin::test_ddpm_ddim_equality_batched` was breaking for `batch_size != 1` (marked as "needs investigation" by @anton-l).
Why this happens
It seems that DDIM and DDPM with an equal number of inference steps are only equivalent when `use_clipped_model_output=True`. With `use_clipped_model_output=False`, even one step of the schedulers produces different results; if, however, that flag is switched to `True`, the difference shrinks by a couple of orders of magnitude, and the remainder is probably numerical error.
To check this, I modified the code to pass `use_clipped_model_output=True` into the DDIM scheduler step, and saw that both `test_ddpm_ddim_equality` and `test_ddpm_ddim_equality_batched` pass independently of the seed and batch size.
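To illustrate why the flag matters, here is a minimal, self-contained sketch of a single reverse step in plain NumPy (it reimplements the update formulas, not the actual diffusers schedulers, and the beta schedule is an arbitrary toy choice). With `eta=1`, the DDIM update mean coincides with the DDPM posterior mean only if the noise prediction is recomputed from the clipped `x0` estimate, which is what `use_clipped_model_output=True` does; reusing the original noise prediction makes the pair `(x0_clipped, eps)` inconsistent and the means diverge.

```python
import numpy as np

# Toy noise schedule, for illustration only
betas = np.linspace(1e-4, 0.02, 10)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

t = 5  # current timestep; previous timestep is t - 1
a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
beta_t = betas[t]

rng = np.random.default_rng(0)
x_t = rng.normal(size=4) * 3.0  # scaled up so that pred_x0 leaves [-1, 1]
eps = rng.normal(size=4)        # stand-in for the model's noise prediction

# Predicted x0, clipped to [-1, 1] (DDPM's clip_sample behaviour)
pred_x0 = (x_t - np.sqrt(1 - a_t) * eps) / np.sqrt(a_t)
x0_clip = np.clip(pred_x0, -1.0, 1.0)

# DDPM posterior mean, computed from the clipped x0
c1 = np.sqrt(a_prev) * beta_t / (1 - a_t)
c2 = np.sqrt(alphas[t]) * (1 - a_prev) / (1 - a_t)
ddpm_mean = c1 * x0_clip + c2 * x_t

# DDIM with eta = 1: sigma^2 equals the DDPM posterior variance
sigma2 = (1 - a_prev) / (1 - a_t) * beta_t

# use_clipped_model_output=True: recompute eps from the clipped x0
eps_reclip = (x_t - np.sqrt(a_t) * x0_clip) / np.sqrt(1 - a_t)
ddim_mean_clipped = np.sqrt(a_prev) * x0_clip + np.sqrt(1 - a_prev - sigma2) * eps_reclip

# use_clipped_model_output=False: keep the original eps (inconsistent with x0_clip)
ddim_mean_unclipped = np.sqrt(a_prev) * x0_clip + np.sqrt(1 - a_prev - sigma2) * eps

print(np.abs(ddim_mean_clipped - ddpm_mean).max())    # ~0, numerical error only
print(np.abs(ddim_mean_unclipped - ddpm_mean).max())  # noticeably larger
```

The equality in the clipped case is an algebraic identity (`sqrt(1 - a_prev - sigma2) = sqrt(alphas[t]) * (1 - a_prev) / sqrt(1 - a_t)` when `eta = 1`), so the first difference is pure float error, while the second is driven entirely by the elements where clipping changed `pred_x0`.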
Solutions
So one can do one of the following:
- Take `use_clipped_model_output` as an argument to `DDIMPipeline.__call__` and pass it down to `DDIMScheduler.step`. The tests would then call `DDIMPipeline` with `use_clipped_model_output=True`.
- Call `DDIMScheduler.step` with `use_clipped_model_output=True` from `DDIMPipeline.__call__`.
- Change the default value of `use_clipped_model_output` in `DDIMScheduler.step` to `True`.
The first option would give users maximum flexibility: they can set `use_clipped_model_output` to `True` to have alignment with DDPM, or to `False` to follow the original implementation.
Let me know what you think!
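For the first option, the plumbing could look roughly like the sketch below. The classes here are simplified stand-ins, not the real diffusers implementations; the point is only the signature change and the forwarding of the flag.

```python
# Simplified stand-ins illustrating the proposed plumbing: the pipeline
# accepts `use_clipped_model_output` and forwards it to the scheduler's step().

class ToyDDIMScheduler:
    def step(self, model_output, timestep, sample, eta=0.0,
             use_clipped_model_output=False):
        # Placeholder for the real DDIM update; here we just record the flag
        # so the forwarding can be checked.
        self.last_flag = use_clipped_model_output
        return sample  # a real scheduler would return the denoised sample


class ToyDDIMPipeline:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def __call__(self, sample, num_inference_steps=50, eta=0.0,
                 use_clipped_model_output=False):
        for t in range(num_inference_steps):
            model_output = sample  # placeholder for unet(sample, t)
            sample = self.scheduler.step(
                model_output, t, sample, eta=eta,
                use_clipped_model_output=use_clipped_model_output,  # forwarded
            )
        return sample


pipeline = ToyDDIMPipeline(ToyDDIMScheduler())
pipeline(0.0, num_inference_steps=2, eta=1.0, use_clipped_model_output=True)
print(pipeline.scheduler.last_flag)  # True
```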
Reproduction

```python
def test_ddpm_ddim_equality(self):
    model_id = "google/ddpm-cifar10-32"
    unet = UNet2DModel.from_pretrained(model_id)
    ddpm_scheduler = DDPMScheduler(tensor_format="pt")
    ddim_scheduler = DDIMScheduler(tensor_format="pt")
    ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
    ddpm.to(torch_device)
    ddpm.set_progress_bar_config(disable=None)
    ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
    ddim.to(torch_device)
    ddim.set_progress_bar_config(disable=None)
    generator = torch.manual_seed(4)
    ddpm_image = ddpm(generator=generator, output_type="numpy").images
    generator = torch.manual_seed(4)
    ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
    # the values aren't exactly equal, but the images look the same visually
    assert np.abs(ddpm_image - ddim_image).max() < 1e-1
```

```shell
RUN_SLOW=1 pytest -vvv -s tests/test_pipelines.py::PipelineTesterMixin::test_ddpm_ddim_equality
```
Logs
============================= test session starts ==============================
platform linux -- Python 3.7.14, pytest-7.1.3, pluggy-1.0.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /content/diffusers
plugins: forked-1.4.0, xdist-2.5.0, typeguard-2.7.1
collected 1 item
100% 1000/1000 [00:19<00:00, 51.25it/s]
100% 1000/1000 [00:18<00:00, 54.93it/s]
FAILED
=================================== FAILURES ===================================
_________________ PipelineTesterMixin.test_ddpm_ddim_equality __________________
self = <tests.test_pipelines.PipelineTesterMixin testMethod=test_ddpm_ddim_equality>
@slow
def test_ddpm_ddim_equality(self):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)
ddpm_scheduler = DDPMScheduler(tensor_format="pt")
ddim_scheduler = DDIMScheduler(tensor_format="pt")
ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(4)
ddpm_image = ddpm(generator=generator, output_type="numpy").images
generator = torch.manual_seed(4)
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
# the values aren't exactly equal, but the images look the same visually
> assert np.abs(ddpm_image - ddim_image).max() < 1e-1
E AssertionError: assert 0.43824133 < 0.1
E + where 0.43824133 = <built-in method max of numpy.ndarray object at 0x7fb3e40a6cf0>()
E + where <built-in method max of numpy.ndarray object at 0x7fb3e40a6cf0> = array([[[[0.04558021, 0.03272051, 0.00290424],\n [0.03015682, 0.01549941, 0.01387468],\n [0.01811638, 0.00187418, 0.02520341],\n ...,\n [0.01568803, 0.02053806, 0.02523842],\n [0.01075929, 0.01497501, 0.01942393],\n [0.00978824, 0.01409706, 0.0175176 ]],\n\n [[0.05014643, 0.03394508, 0.00047511],\n [0.02992007, 0.01685625, 0.01317757],\n [0.02194265, 0.01309347, 0.01434043],\n ...,\n [0.0106658 , 0.01527333, 0.02131334],\n [0.00817642, 0.01201853, 0.01769397],\n [0.00955456, 0.01317784, 0.01738232]],\n\n [[0.01120478, 0.00307593, 0.0346739 ],\n [0.00118643, 0.00567415, 0.03347611],\n [0.01370427, 0.00763524, 0.03167981],\n ...,\n [0.00381142, 0.00819463, 0.01563641],\n [0.00532115, 0.00892332, 0.01543626],\n [0.0115537 , 0.01504925, 0.02020285]],\n\n ...,\n\n [[0.08257291, 0.11392984, 0.09909356],\n [0.05928966, 0.08817127, 0.07422128],\n [0.05476227, 0.08412179, 0.07177079],\n ...,\n [0.01535362, 0.03447109, 0.01354346],\n [0.00603926, 0.02593628, 0.00519031],\n [0.00404036, 0.01757577, 0.00456756]],\n\n [[0.0805698 , 0.11143386, 0.09731463],\n [0.05878767, 0.08807385, 0.07573456],\n [0.04590842, 0.07566378, 0.06472993],\n ...,\n [0.01950353, 0.03736839, 0.01339555],\n [0.02005506, 0.03832331, 0.01469368],\n [0.01592958, 0.03648573, 0.01218173]],\n\n [[0.01912883, 0.04568186, 0.03250527],\n [0.00111482, 0.02680618, 0.01622108],\n [0.00094694, 0.0268881 , 0.01698533],\n ...,\n [0.02942401, 0.04527366, 0.02022967],\n [0.03194857, 0.04855642, 0.02307808],\n [0.03390098, 0.053085 , 0.02657506]]]], dtype=float32).max
E + where array([[[[0.04558021, 0.03272051, 0.00290424],\n [0.03015682, 0.01549941, 0.01387468],\n [0.01811638, 0.00187418, 0.02520341],\n ...,\n [0.01568803, 0.02053806, 0.02523842],\n [0.01075929, 0.01497501, 0.01942393],\n [0.00978824, 0.01409706, 0.0175176 ]],\n\n [[0.05014643, 0.03394508, 0.00047511],\n [0.02992007, 0.01685625, 0.01317757],\n [0.02194265, 0.01309347, 0.01434043],\n ...,\n [0.0106658 , 0.01527333, 0.02131334],\n [0.00817642, 0.01201853, 0.01769397],\n [0.00955456, 0.01317784, 0.01738232]],\n\n [[0.01120478, 0.00307593, 0.0346739 ],\n [0.00118643, 0.00567415, 0.03347611],\n [0.01370427, 0.00763524, 0.03167981],\n ...,\n [0.00381142, 0.00819463, 0.01563641],\n [0.00532115, 0.00892332, 0.01543626],\n [0.0115537 , 0.01504925, 0.02020285]],\n\n ...,\n\n [[0.08257291, 0.11392984, 0.09909356],\n [0.05928966, 0.08817127, 0.07422128],\n [0.05476227, 0.08412179, 0.07177079],\n ...,\n [0.01535362, 0.03447109, 0.01354346],\n [0.00603926, 0.02593628, 0.00519031],\n [0.00404036, 0.01757577, 0.00456756]],\n\n [[0.0805698 , 0.11143386, 0.09731463],\n [0.05878767, 0.08807385, 0.07573456],\n [0.04590842, 0.07566378, 0.06472993],\n ...,\n [0.01950353, 0.03736839, 0.01339555],\n [0.02005506, 0.03832331, 0.01469368],\n [0.01592958, 0.03648573, 0.01218173]],\n\n [[0.01912883, 0.04568186, 0.03250527],\n [0.00111482, 0.02680618, 0.01622108],\n [0.00094694, 0.0268881 , 0.01698533],\n ...,\n [0.02942401, 0.04527366, 0.02022967],\n [0.03194857, 0.04855642, 0.02307808],\n [0.03390098, 0.053085 , 0.02657506]]]], dtype=float32) = <ufunc 'absolute'>((array([[[[0.40101475, 0.46291834, 0.4121176 ],\n [0.34825665, 0.4004196 , 0.35795364],\n [0.2726246 , 0.30890313, 0.28333575],\n ...,\n [0.03085664, 0.03181294, 0.02819133],\n [0.02086985, 0.02229974, 0.0206081 ],\n [0.00369063, 0.00407588, 0.00600827]],\n\n [[0.22320473, 0.26596433, 0.23987389],\n [0.26899457, 0.3037504 , 0.27794272],\n [0.2812881 , 0.3059944 , 0.2841526 ],\n ...,\n [0.04489943, 0.04644141, 0.04107878],\n 
[0.02853832, 0.03078747, 0.02764395],\n [0.00990564, 0.01102397, 0.01143444]],\n\n [[0.07842988, 0.10028672, 0.0905675 ],\n [0.0958043 , 0.11334005, 0.10073435],\n [0.10045779, 0.11351112, 0.09573883],\n ...,\n [0.05968505, 0.0629513 , 0.05531797],\n [0.03551909, 0.03880513, 0.0335691 ],\n [0.01331115, 0.01551712, 0.01374137]],\n\n ...,\n\n [[0.20222887, 0.23634496, 0.16945627],\n [0.23803082, 0.27467984, 0.20663318],\n [0.25242266, 0.29079053, 0.2221111 ],\n ...,\n [0.20742086, 0.25217715, 0.21699208],\n [0.23207286, 0.27742156, 0.2419188 ],\n [0.27641273, 0.32188419, 0.2889324 ]],\n\n [[0.2304188 , 0.2675364 , 0.2034539 ],\n [0.26692167, 0.30683082, 0.23949957],\n [0.29059297, 0.33213654, 0.2644921 ],\n ...,\n [0.2891584 , 0.3423161 , 0.31133837],\n [0.2979223 , 0.35372874, 0.3214634 ],\n [0.31305268, 0.36887997, 0.33849669]],\n\n [[0.3601327 , 0.39852542, 0.3433386 ],\n [0.3891532 , 0.43035895, 0.36975437],\n [0.39519945, 0.43837112, 0.37691107],\n ...,\n [0.3303221 , 0.3882214 , 0.35804674],\n [0.33242843, 0.3934024 , 0.36258858],\n [0.32387304, 0.3849951 , 0.35492 ]]]], dtype=float32) - array([[[[0.35543454, 0.43019783, 0.40921336],\n [0.31809983, 0.38492018, 0.37182832],\n [0.25450823, 0.30702895, 0.30853915],\n ...,\n [0.04654467, 0.052351 , 0.05342975],\n [0.03162915, 0.03727475, 0.04003203],\n [0.01347888, 0.01817295, 0.02352586]],\n\n [[0.1730583 , 0.23201925, 0.23939878],\n [0.2390745 , 0.28689414, 0.2911203 ],\n [0.25934544, 0.29290092, 0.29849303],\n ...,\n [0.05556524, 0.06171474, 0.06239212],\n [0.03671473, 0.042806 , 0.04533792],\n [0.0194602 , 0.02420181, 0.02881676]],\n\n [[0.0672251 , 0.10336265, 0.1252414 ],\n [0.09461787, 0.1190142 , 0.13421047],\n [0.11416206, 0.12114635, 0.12741864],\n ...,\n [0.06349647, 0.07114592, 0.07095438],\n [0.04084024, 0.04772845, 0.04900536],\n [0.02486485, 0.03056636, 0.03394422]],\n\n ...,\n\n [[0.28480178, 0.3502748 , 0.26854983],\n [0.2973205 , 0.3628511 , 0.28085446],\n [0.30718493, 0.37491232, 0.2938819 ],\n 
...,\n [0.22277448, 0.28664824, 0.23053554],\n [0.23811212, 0.30335784, 0.24710912],\n [0.27237236, 0.33945996, 0.28436485]],\n\n [[0.3109886 , 0.37897027, 0.30076852],\n [0.32570934, 0.39490467, 0.31523412],\n [0.3365014 , 0.40780032, 0.32922202],\n ...,\n [0.30866194, 0.37968448, 0.3247339 ],\n [0.31797737, 0.39205205, 0.33615708],\n [0.32898226, 0.4053657 , 0.3506784 ]],\n\n [[0.37926152, 0.44420728, 0.37584388],\n [0.39026803, 0.45716512, 0.38597545],\n [0.3961464 , 0.46525922, 0.3938964 ],\n ...,\n [0.3597461 , 0.43349507, 0.3782764 ],\n [0.364377 , 0.4419588 , 0.38566667],\n [0.35777402, 0.4380801 , 0.38149506]]]], dtype=float32)))
E + where <ufunc 'absolute'> = np.abs
tests/test_pipelines.py:1064: AssertionError
=========================== short test summary info ============================
FAILED tests/test_pipelines.py::PipelineTesterMixin::test_ddpm_ddim_equality
============================== 1 failed in 43.19s ==============================
System Info
- `diffusers` version: 0.4.0.dev0
- Platform: Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.14
- PyTorch version (GPU?): 1.12.1+cu113 (True)
- Huggingface_hub version: 0.9.1
- Transformers version: 4.22.1
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Issue Analytics
- Created a year ago
- Comments: 5 (4 by maintainers)
Top GitHub Comments
Unfortunately I couldn't reproduce the improvement that you saw with `use_clipped_model_output`, @sgrigory. The batched comparison still suffers too (max difference at around 0.37). Could you confirm if it works with the latest versions of the schedulers? Curious if we had a brief moment when the schedulers really were compatible.

@anton-l, I've double-checked, and it seems that with the latest version of the code the situation is the same: by passing `use_clipped_model_output=True` one can make the DDIM and DDPM pipelines compatible. Here is a PR with the proposed changes: #1069. Please have a look and let me know, I might be missing something. The tests in question were not executed in the CI yet, but I ran them on Colab (see the screenshot in the PR description).