MPS crash when using LMSDiscreteScheduler
Describe the bug
Running the following code with LMSDiscreteScheduler crashes on Apple Silicon (MPS):
```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from PIL import Image
from torchvision import transforms as tfms

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)

# The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)

# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        # Note scaling: ToTensor() yields [0, 1]; "* 2 - 1" maps it to [-1, 1]
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()

# Load the image with PIL
input_image = Image.open('macaw.jpg').resize((512, 512))

# Encode to the latent space
encoded = pil_to_latent(input_image)

# Setting the number of sampling steps:
scheduler.set_timesteps(15)

noise = torch.randn_like(encoded)
sampling_step = 10
encoded_and_noised = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[sampling_step]]))
```
The crash is as follows:
```
Traceback (most recent call last):
  File "/Users/fahim/Code/Python/fastai/diffusion-nbs/t.py", line 36, in <module>
    encoded_and_noised = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[sampling_step]]))
  File "/Users/fahim/miniconda3/envs/ml/lib/python3.9/site-packages/diffusers/schedulers/scheduling_lms_discrete.py", line 255, in add_noise
    self.timesteps = self.timesteps.to(original_samples.device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
```
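The traceback shows the failure happens when `add_noise` moves `self.timesteps` onto the sample's device: the timesteps tensor is float64, and MPS cannot hold float64 tensors. One possible user-side workaround is to downcast float64 to float32 before a tensor is moved to MPS. This is a minimal sketch, not necessarily how the fix in diffusers was implemented. `mps_safe_dtype` and `to_device_safely` are hypothetical helpers, and the stand-in timesteps assume (as the error message implies) that the scheduler's timesteps come from a float64 NumPy array.

```python
import numpy as np
import torch


def mps_safe_dtype(device: torch.device, dtype: torch.dtype) -> torch.dtype:
    """Return a dtype the target device can hold.

    MPS has no float64 support, so float64 is downgraded to float32 there;
    every other (device, dtype) pair is returned unchanged.
    """
    if device.type == "mps" and dtype == torch.float64:
        return torch.float32
    return dtype


def to_device_safely(t: torch.Tensor, device) -> torch.Tensor:
    """Move `t` to `device`, casting float64 -> float32 first when the target is MPS."""
    device = torch.device(device)
    target = mps_safe_dtype(device, t.dtype)
    # Cast on the current device before moving, so no float64 MPS tensor is ever requested.
    return t.to(dtype=target).to(device)


# Hypothetical stand-in for scheduler.timesteps: assumed to be built from a
# default (float64) NumPy linspace, reversed and copied to get positive strides.
timesteps = torch.from_numpy(np.linspace(0, 999, 15)[::-1].copy())
print(timesteps.dtype)  # torch.float64

device = "mps" if torch.backends.mps.is_available() else "cpu"
moved = to_device_safely(timesteps, device)
print(moved.device, moved.dtype)
```

Casting on the source device first means a float64 tensor is never requested on MPS. The trade-off is precision: float32 represents whole numbers up to 2**24 exactly, so integer timesteps in 0..999 survive the cast, but fractional timesteps are rounded.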
Reproduction
Code is provided above
System Info
- Python 3.9.13
- diffusers 0.6.0
- torch 1.14.0.dev20221021
- macOS 12.6
- Apple M1 Max, 32GB
Comments
Thanks @pcuenca 🙂 I can confirm that it works correctly now.
Should be fixed now.