Allow non-default schedulers to be easily swapped into `DiffusionPipeline` classes

Is your feature request related to a problem? Please describe.

By default the stable diffusion pipeline uses the PNDM scheduler, but one could easily use other schedulers: we only need to overwrite the self.scheduler attribute.

This can be done with the following code-snippet:

from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

# use DDIM scheduler
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)
pipe.scheduler = scheduler

Now, that’s a bit hacky and not the way we want users to do it ideally!

Describe the solution you’d like

Instead, the following code snippet should work, more or less, for all pipelines:

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", scheduler=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

This is a cleaner and more intuitive API. The idea is that every class variable that can be passed to https://github.com/huggingface/diffusers/blob/051b34635fda2fc310898a6a602c89be8663b77f/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L15 should also be overwritable when using from_pretrained(...).

Currently, running this command fails with:

TypeError: cannot unpack non-iterable DDIMScheduler object 
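
The message itself is ordinary Python: it is raised whenever a tuple-unpacking assignment receives an object that is not iterable. One plausible reading, which is an assumption and not confirmed in this issue, is that the loader expects every entry to be a (library_name, class_name) pair and gets the scheduler instance instead. The sketch below only reproduces the message; the DDIMScheduler class and both helper functions are hypothetical stand-ins, not diffusers code.

```python
class DDIMScheduler:
    """Hypothetical stand-in; NOT the real diffusers class."""


def unpack_entry(entry):
    # Assumption: the loader expects a (library_name, class_name) pair here.
    library_name, class_name = entry
    return library_name, class_name


def try_unpack(entry):
    # Return the unpacked pair, or the TypeError text if unpacking fails.
    try:
        return unpack_entry(entry)
    except TypeError as err:
        return str(err)


ok = try_unpack(("diffusers", "PNDMScheduler"))
message = try_unpack(DDIMScheduler())
print(ok)
print(message)
```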

Now we can allow such behavior by adding some logic to the general DiffusionPipeline from_pretrained method here: https://github.com/huggingface/diffusers/blob/051b34635fda2fc310898a6a602c89be8663b77f/src/diffusers/pipeline_utils.py#L115

We also want this approach to work not just for one pipeline and the scheduler class, but for all pipelines and all scheduler classes. We can achieve this by doing more or less the following in https://github.com/huggingface/diffusers/blob/051b34635fda2fc310898a6a602c89be8663b77f/src/diffusers/pipeline_utils.py#L115:

Pseudo code:

  1. Retrieve all variables that can be passed to the class init here: https://github.com/huggingface/diffusers/blob/051b34635fda2fc310898a6a602c89be8663b77f/src/diffusers/pipeline_utils.py#L149 -> you should get a list of keys such as [vae, text_encoder, tokenizer, unet, scheduler]
  2. Check whether any of those parameters are passed in kwargs -> if yes -> store them in a dict passed_class_obj
  3. In the loop that loads the class variables: https://github.com/huggingface/diffusers/blob/051b34635fda2fc310898a6a602c89be8663b77f/src/diffusers/pipeline_utils.py#L162 add a new if statement that checks whether the name is in the passed_class_obj dict -> if yes -> simply use the passed object and skip the loading part (set loaded_sub_model to the passed object)
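
The three steps above can be sketched roughly as follows. This is a toy illustration under stated assumptions, not the diffusers implementation: ToyPipeline, ToyScheduler, load_sub_model and this standalone from_pretrained are all hypothetical, and the real method also resolves the model path and config, which is omitted here.

```python
import inspect


class ToyScheduler:
    """Hypothetical scheduler used only for this sketch."""

    def __init__(self, name="pndm"):
        self.name = name


class ToyPipeline:
    """Hypothetical pipeline; the real pipeline classes live in diffusers."""

    def __init__(self, unet, scheduler):
        self.unet = unet
        self.scheduler = scheduler


def load_sub_model(name):
    # Placeholder for the real loading code (reading config and weights).
    return f"loaded:{name}"


def from_pretrained(pipeline_class, **kwargs):
    # Step 1: collect the names accepted by the pipeline's __init__.
    params = inspect.signature(pipeline_class.__init__).parameters
    expected = [name for name in params if name != "self"]

    # Step 2: move any matching kwargs into passed_class_obj.
    passed_class_obj = {k: kwargs.pop(k) for k in expected if k in kwargs}

    # Step 3: use a passed object if present, otherwise load as usual.
    init_kwargs = {}
    for name in expected:
        if name in passed_class_obj:
            loaded_sub_model = passed_class_obj[name]
        else:
            loaded_sub_model = load_sub_model(name)
        init_kwargs[name] = loaded_sub_model

    return pipeline_class(**init_kwargs)


custom = ToyScheduler("ddim")
pipe = from_pretrained(ToyPipeline, scheduler=custom)
default_pipe = from_pretrained(ToyPipeline)
```

Because the passed object is substituted inside the existing loading loop, nothing has to be re-initialized after the pipeline is built.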

=> after the PR this should work:

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", scheduler=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

whereas this should give a nice error message (note how the scheduler is incorrectly passed to vae):

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", vae=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

The error message can be based on the passed class not sharing a parent class with the expected one (this could be checked using this dict: https://github.com/huggingface/diffusers/blob/051b34635fda2fc310898a6a602c89be8663b77f/src/diffusers/pipeline_utils.py#L34 ).
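
A minimal sketch of such a check, assuming a mapping from component name to expected base class. The EXPECTED_BASE dict, the check_passed_component helper and all four classes are hypothetical stand-ins defined here for illustration; the structure of the real dict linked above may differ.

```python
class ModelMixin:
    """Hypothetical stand-in for a model base class."""


class SchedulerMixin:
    """Hypothetical stand-in for a scheduler base class."""


class AutoencoderKL(ModelMixin):
    pass


class DDIMScheduler(SchedulerMixin):
    pass


# Assumption: each pipeline component name maps to one expected base class.
EXPECTED_BASE = {"vae": ModelMixin, "scheduler": SchedulerMixin}


def check_passed_component(name, obj):
    # Reject objects that do not inherit from the base class expected for `name`.
    expected = EXPECTED_BASE[name]
    if not isinstance(obj, expected):
        raise ValueError(
            f"{name} was passed a {type(obj).__name__}, "
            f"but an instance of {expected.__name__} was expected."
        )
    return obj


scheduler = check_passed_component("scheduler", DDIMScheduler())

try:
    check_passed_component("vae", DDIMScheduler())
    error_text = None
except ValueError as err:
    error_text = str(err)
```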

Additional context

As suggested by @apolinario, it’s very important to allow one to easily swap out schedulers. At the same time, we don’t want to create too much custom code. IMO the solution above handles the problem nicely.

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 12 (8 by maintainers)

Top GitHub Comments

1 reaction
patrickvonplaten commented, Aug 16, 2022

@patil-suraj no, it shouldn’t init the pipeline again. If implemented correctly, there is no need to re-initialize anything. I’m not in favor of a set_scheduler method, as this opens the door to many problems (e.g. we would adapt the model_index config after init, which is not a good idea IMO).
