Use consistent random number generation across hardware
Is your feature request related to a problem? Please describe.
torch.randn is not consistent across hardware devices (see https://github.com/pytorch/pytorch/issues/84234). diffusers calls torch.randn on the device the computation runs on (typically ‘cuda’). As a result, results produced with the exact same parameters will differ across machines.
Describe the solution you’d like
Until the issue is resolved in pytorch itself, diffusers should use a deterministic RNG so results can be consistent across hardware. One possible workaround is to keep using torch.randn while forcing generation to happen on the cpu, which currently seems consistent no matter the hardware. Here is an example solution:
import torch


def randn(size, generator=None, device=None, **kwargs):
    """
    Wrapper around torch.randn providing proper reproducibility.

    Generation is done on the given generator's device, then the result is
    moved to the given ``device``.

    Args:
        size: tensor size
        generator (torch.Generator): RNG generator
        device (torch.device): target device for the resulting tensor
    """
    # FIXME: the generator's RNG device is ignored by torch.randn, so it has
    # to be passed explicitly as the ``device`` argument (torch issue #62451)
    rng_device = generator.device if generator is not None else device
    image = torch.randn(size, generator=generator, device=rng_device, **kwargs)
    image = image.to(device=device)
    return image


def randn_like(tensor, generator=None, **kwargs):
    # Match the input tensor's shape, layout and dtype, but draw the values on
    # the generator's (typically cpu) device before moving them to tensor.device.
    return randn(tensor.shape, layout=tensor.layout, dtype=tensor.dtype, generator=generator, device=tensor.device, **kwargs)
Calling these functions instead of the torch ones, with a generator whose device is cpu, gives deterministic results and still allows the rest of the computation to run on cuda. This would also simplify and speed up all the tests, which can simply use cpu-bound generators and leave device set to cuda even for those relying on RNG. A usage example is sketched below.
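For illustration, usage in a pipeline could look roughly like this (the seed and tensor shape are arbitrary, and a CUDA device is assumed to be available):

# The generator is CPU-bound, so the noise values are identical on every
# machine, while the resulting tensors live on cuda for the actual compute.
generator = torch.Generator(device="cpu").manual_seed(0)
latents = randn((1, 4, 64, 64), generator=generator, device=torch.device("cuda"))
noise = randn_like(latents, generator=generator)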
Describe alternatives you’ve considered
It’s also possible to switch to numpy’s RNG, which is deterministic across hardware. The above solution is more torch-native.
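For completeness, the numpy-based alternative would look roughly like this sketch (shape and seed are arbitrary; the final .to() assumes a CUDA device is available):

import numpy as np
import torch

# Draw the noise with numpy's Generator, which is bit-exact across machines,
# then hand it to torch and move it to the compute device afterwards.
rng = np.random.default_rng(seed=0)
noise = torch.from_numpy(rng.standard_normal(size=(1, 4, 64, 64), dtype=np.float32))
noise = noise.to(device="cuda")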
Top GitHub Comments
Reproducibility/determinism is not just a "nice thing", it’s vitally important for multiple reasons.
Agree with Patrick here, IMO it’s better to add a method which can enable reproducibility. I’m usually hesitant to create such wrappers around framework functions as it might make it a bit harder to go through the code. For example, if we replace torch.randn with a custom randn, users going through the code might wonder whether something different is happening here.
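For illustration only, a minimal sketch of the "explicit method" idea mentioned above; the name enable_reproducible_rng is hypothetical and not an existing diffusers API:

import torch

def enable_reproducible_rng(seed: int) -> torch.Generator:
    # Hypothetical helper: return a CPU-bound generator so sampling is
    # identical across hardware; callers pass it explicitly to the pipeline
    # instead of diffusers wrapping torch.randn internally.
    return torch.Generator(device="cpu").manual_seed(seed)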