JAX Integration

This issue will be used to track the native integration of Stable Diffusion in JAX into diffusers. This will enable many cool use cases, notably running Stable Diffusion on Google Colab.

General design:

We will loosen the hard PyTorch dependency and instead require the user to install either PyTorch or JAX. Then we will mirror the following “base” classes with JAX-compatible versions:
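A minimal sketch of the “PyTorch or JAX” requirement, assuming the check runs at import time and simply probes for the `torch` and `jax` packages with the standard-library `importlib.util.find_spec`. The names `available_backends` and `require_backend` are hypothetical and not part of diffusers; the `find_spec` parameter exists only so the check can be exercised without either framework installed.

```python
import importlib.util
from typing import Callable, List, Optional


def available_backends(
    find_spec: Callable[[str], Optional[object]] = importlib.util.find_spec,
) -> List[str]:
    """Return the supported backends ("torch", "jax") that can be imported."""
    return [name for name in ("torch", "jax") if find_spec(name) is not None]


def require_backend(
    find_spec: Callable[[str], Optional[object]] = importlib.util.find_spec,
) -> List[str]:
    """Hypothetical import-time guard: at least one backend must be installed."""
    backends = available_backends(find_spec)
    if not backends:
        raise ImportError(
            "This library needs either PyTorch or JAX. Install one of them, "
            "e.g. `pip install torch` or `pip install jax`."
        )
    return backends
```

For example, `require_backend(lambda name: object() if name == "jax" else None)` returns `["jax"]`, while a finder that returns `None` for both names raises `ImportError`.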

  • ModelMixin (https://github.com/patil-suraj/stable-diffusion-jax/pull/10): we should add a FlaxModelMixin class here.
  • DiffusionPipeline (https://github.com/huggingface/diffusers/blob/25a51b63ca75e1351069bee87a0fb3df5abb89c3/src/diffusers/pipeline_utils.py#L76): we should add a FlaxDiffusionPipeline class here.

Note: ModelMixin should be stateless by default, i.e. the weights will not be saved inside the model object. Also, unlike in transformers, should we maybe work only with flax.linen.Module classes here, @patil-suraj? I don't really think we need the separate UNetConditionModel / UNetConditionModule design; we could just write class UNetConditionModel(nn.Module): and make sure everything stays stateless, no?

TODO:

Happy to take over 1. and finish it today, then look into 4. once 3. is done.

@mishig25, do you want to do 2.? (Happy to guide you a bit if you have questions. We may also need to discuss the design offline.)

1. & 5.: @pcuenca, do you want to take these? (I think 3. is more important here.)

The other parts we can see tomorrow maybe 😃

Issue Analytics

  • State: closed
  • Created: a year ago
  • Reactions: 4
  • Comments: 10 (10 by maintainers)

Top GitHub Comments

5 reactions
patrickvonplaten commented, Sep 12, 2022

Asked Flax team about design here: https://github.com/google/flax/discussions/2454

3 reactions
pcuenca commented, Oct 27, 2022

I think this should be complete now. @patil-suraj @patrickvonplaten, please feel free to reopen otherwise.
