JAX Integration
This issue will be used as a tracker for integrating Stable Diffusion in JAX natively into diffusers. This will enable many cool use cases, notably running Stable Diffusion on a Google Colab.
General design:
We will loosen the forced PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following “base” classes to be JAX compatible:
- ModelMixin: https://github.com/patil-suraj/stable-diffusion-jax/pull/10 we should add a FlaxModelMixin class here.
- FlaxDiffusionPipeline: https://github.com/huggingface/diffusers/blob/25a51b63ca75e1351069bee87a0fb3df5abb89c3/src/diffusers/pipeline_utils.py#L76 we should add a FlaxDiffusionPipeline here.
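To make the "either PyTorch or JAX" requirement concrete, here is a minimal sketch of how an availability check could look. The helper names (is_available, detect_backends, require_any_backend) are hypothetical and not taken from diffusers, and treating the JAX path as needing both jax and flax is an assumption.

```python
import importlib.util


def is_available(package: str) -> bool:
    """Return True if a top-level package can be found without importing it."""
    return importlib.util.find_spec(package) is not None


def detect_backends() -> dict:
    """Report which of the two supported frameworks are installed."""
    return {
        "torch": is_available("torch"),
        # Assumption: the JAX code path needs both jax and flax.
        "flax": is_available("jax") and is_available("flax"),
    }


def require_any_backend(backends: dict) -> list:
    """Return the names of installed backends, or fail if neither is present."""
    installed = [name for name, found in backends.items() if found]
    if not installed:
        raise ImportError(
            "diffusers needs either PyTorch or JAX/Flax; please install one of them."
        )
    return installed
```

Calling require_any_backend(detect_backends()) once at import time would let the package load with only one framework installed, while framework-specific classes stay behind their own checks.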
Note: ModelMixin should be made stateless by default, e.g. weights will not be saved. Also, contrary to transformers, should we maybe only work with flax.linen.Module classes here, @patil-suraj? I don’t really think we need the UNetConditionModel and UNetConditionModule design here. We could just go for class UNetConditionModel(nn.Module): here and make sure everything stays stateless, no?
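As an illustration of what "stateless" means in this note, here is a framework-free sketch in plain Python. It only mimics the init/apply split that flax.linen.Module uses and is not Flax code; the TinyLinear class and its fields are hypothetical. The object holds hyperparameters only, and the weights are created by init and passed explicitly to apply.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class TinyLinear:
    """Holds only hyperparameters; the weights live outside the object."""

    in_features: int
    out_features: int

    def init(self, seed: int) -> dict:
        """Create a fresh parameter dict; nothing is stored on self."""
        rng = random.Random(seed)
        kernel = [
            [rng.uniform(-0.1, 0.1) for _ in range(self.out_features)]
            for _ in range(self.in_features)
        ]
        return {"kernel": kernel, "bias": [0.0] * self.out_features}

    def apply(self, params: dict, x: list) -> list:
        """Compute x @ kernel + bias using the parameters passed in."""
        return [
            sum(x[i] * params["kernel"][i][j] for i in range(self.in_features))
            + params["bias"][j]
            for j in range(self.out_features)
        ]


layer = TinyLinear(in_features=3, out_features=2)
params = layer.init(seed=0)                   # weights are returned to the caller
y = layer.apply(params, [1.0, 2.0, 3.0])      # and handed back in on every call
```

Because the module never owns its weights, the same layer object can be reused with different parameter sets, which is the property the single-class design in the note relies on.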
TODO:
- 1. Make diffusers framework independent. This will require some general changes to setup.py and our automation tools
- 2. Add FlaxModelMixin: https://github.com/huggingface/diffusers/pull/493 Here we can take a lot from https://github.com/patil-suraj/stable-diffusion-jax/pull/10/files but I’m not sure we should follow the transformers design here 1-to-1. Will also ask some Google folks here
- 3. Add all the modeling code under unet_2d_condition_flax.py …
- 4. #478
- 5. Add PNDM scheduler under scheduling_pndm_flax.py
- 6. Tests
- 7. Create pipeline and also FlaxDiffusionPipeline
Happy to take over 1. and finish today and then look into 4. once 3. is done.
@mishig25 do you want to do 2.? (happy to guide you here a bit if you have questions. Also we need to discuss the design here a bit offline maybe)
3. & 5. @pcuenca do you want to take this? (think 3. is more important here)
The other parts we can see tomorrow maybe 😃
Top GitHub Comments
Asked Flax team about design here: https://github.com/google/flax/discussions/2454
I think this should be complete now, please @patil-suraj @patrickvonplaten feel free to reopen otherwise.