[Community Discussion] Scheduler design
This is a very nice PR by @pcuenca showing what changes need to be made to the PNDM/PLMS scheduler to make it work with JAX/XLA. It’s actually more than anticipated and shows that the scheduler now differs substantially from the original implementation that we use for PyTorch.
In the coming days we will integrate these changes into main diffusers to make the library compatible with Flax/JAX. Now the big question is, should we:
a) Make each scheduler very generic and continue the set_format("pt") logic? While this would make sense logically, as the schedulers don’t really store any trainable weights, it could lead to quite a few if-else statements and overly abstracted code, e.g. lots of self.where(...) functions in scheduler_utils.py. Also, maybe we want schedulers to have trainable weights in the future? And do we anticipate schedulers becoming more or less complex in the future?
b) Make one scheduler file for each framework. Instead of trying to fit all frameworks into one scheduler file, we write one scheduler per framework. The advantage is clearly readability. Also, most people probably only ever work in one framework, so for them it might be nicer to have the schedulers separate. However, some schedulers will probably be 1-to-1 the same across frameworks (which might not necessarily be a problem).
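To make the trade-off between a) and b) concrete, here is a minimal sketch, not the actual diffusers code. All class names are hypothetical, and so that it runs without any framework installed, Python lists and array.array stand in for the PyTorch and JAX tensor types. The only point is where the branching lives.

```python
from array import array


class GenericScheduler:
    """Option a): one class for all formats (hypothetical sketch)."""

    def __init__(self, betas):
        self.tensor_format = "list"
        self.betas = list(betas)

    def set_format(self, tensor_format):
        # Convert the stored values to the requested container type.
        if tensor_format == "list":
            self.betas = list(self.betas)
        elif tensor_format == "array":
            self.betas = array("d", self.betas)
        else:
            raise ValueError(f"unknown format: {tensor_format}")
        self.tensor_format = tensor_format
        return self

    def where(self, mask, a, b):
        # Every helper needs a branch like this for every supported format.
        picked = [x if m else y for m, x, y in zip(mask, a, b)]
        if self.tensor_format == "array":
            return array("d", picked)
        return picked


class ListScheduler:
    """Option b): one class per format, no dispatch (hypothetical sketch)."""

    def __init__(self, betas):
        self.betas = list(betas)

    def where(self, mask, a, b):
        return [x if m else y for m, x, y in zip(mask, a, b)]


class ArrayScheduler:
    """Option b): the second format gets its own, nearly identical class."""

    def __init__(self, betas):
        self.betas = array("d", betas)

    def where(self, mask, a, b):
        return array("d", (x if m else y for m, x, y in zip(mask, a, b)))
```

With a), every new helper grows another branch per framework. With b), the two classes end up as near-duplicates, which is the "1-to-1 the same" cost mentioned above.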
I’m actually starting to lean more and more towards b) here.
Would love to discuss - cc @anton-l @patil-suraj @natolambert @pcuenca
Issue Analytics
- Created a year ago
- Comments: 9 (8 by maintainers)
Top GitHub Comments
I agree with your points: creating a generic scheduler here will complicate the overall design. In JAX we need to take special care of state, keeping all scheduler state as jnp arrays, avoiding device-to-host communication, etc. So I’m also in favor of option b). It is good for both readability and maintainability.
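A rough sketch of what "taking special care of state" tends to mean in practice: the scheduler itself holds no mutable attributes, and every call takes a state object and returns a new one. The names below (SchedulerState, FunctionalScheduler, step_size) are hypothetical, and the update rule is a placeholder rather than PNDM. Plain tuples are used so the sketch runs without JAX installed; in a real JAX version the fields would be jnp arrays and the state class would be registered as a pytree.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class SchedulerState:
    """Immutable state that is passed in and out of every call."""
    timesteps: Tuple[int, ...]
    counter: int = 0


class FunctionalScheduler:
    """Holds no mutable attributes; all state lives in SchedulerState."""

    def set_timesteps(self, num_inference_steps, num_train_timesteps=1000):
        # Evenly spaced timesteps in descending order.
        stride = num_train_timesteps // num_inference_steps
        timesteps = tuple(range(num_train_timesteps - stride, -1, -stride))
        return SchedulerState(timesteps=timesteps)

    def step(self, state, sample, model_output, step_size=0.1):
        # Placeholder update rule (NOT PNDM): move against the model output.
        prev_sample = tuple(x - step_size * e for x, e in zip(sample, model_output))
        # Return a new state instead of mutating the old one.
        return replace(state, counter=state.counter + 1), prev_sample


scheduler = FunctionalScheduler()
state = scheduler.set_timesteps(4)
new_state, prev_sample = scheduler.step(state, (1.0, 2.0), (1.0, 0.0), step_size=0.5)
```

Because step never mutates its input, the old state is still valid after the call. That is the property a traced or compiled step function relies on, and it is hard to retrofit onto a scheduler that updates self in place.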
Cool, design decision taken -> we’ll have framework-dependent schedulers 😃