thanks (it's 10x faster than JAX)!
I’ve been trying to get dalle-playground running performantly on M1, but there’s a lot of work remaining to make the JAX model work via IREE/Vulkan.
so I tried out your PyTorch model with a recent nightly of PyTorch:
pip install --pre "torch>1.13.0.dev20220610" "torchvision>0.14.0.dev20220609" --extra-index-url https://download.pytorch.org/whl/nightly/cpu
…and it’s 10x faster at dalle-mega than dalle-playground was on JAX/XLA!
using dalle-mega full:
wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1:latest
generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)!
GPU looks less than half utilized. I haven’t checked whether PyTorch is the process that’s using the GPU.
these measurements are from M1 Max.
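A quick way to find out whether PyTorch can see the GPU at all is to ask it whether the Metal (MPS) backend is usable. This is only a minimal sketch: `pick_device` is a hypothetical helper name, not part of this repo, and it says nothing about whether the model code actually places its tensors on that device.

```python
def pick_device():
    """Return "mps" if PyTorch reports a usable Metal backend, else "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not installed in this environment

    # torch.backends.mps only exists in PyTorch 1.12 and later,
    # so look it up defensively instead of assuming it is there.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    print(pick_device())
```

If this prints `cpu`, the GPU activity is probably coming from some other process.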
bonus
“crystal maiden and lina enjoying a pint together at a tavern”
Top GitHub Comments
Even faster these days: you get a 4x4 grid instead of a 3x3 grid on Replicate, after the same duration.
However, this is based on Dall-E MEGA instead of Dall-E Mini, so results might differ. Not sure if better or worse.
I also tried using .contiguous() on any tensor that would be transferred to the MPS device: https://github.com/Birch-san/min-dalle/commit/b1cf6c284a949d23f8c0cd6802bb207b876bf2af
still black.
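For anyone wanting to try the same thing, the idea is roughly the following. This is a sketch of the general pattern, not the code from the linked commit: `to_device_contiguous` and `demo` are hypothetical names, and the demo falls back to the CPU when MPS is not available.

```python
def to_device_contiguous(tensor, device):
    """Make the tensor's memory layout contiguous, then move it to `device`."""
    return tensor.contiguous().to(device)


def demo():
    import torch

    # Fall back to the CPU when the MPS backend is missing or unusable.
    mps = getattr(torch.backends, "mps", None)
    device = "mps" if mps is not None and mps.is_available() else "cpu"

    # Transposing returns a view over the same storage, so it is non-contiguous.
    x = torch.arange(6).reshape(2, 3).t()
    y = to_device_contiguous(x, device)
    return x.is_contiguous(), y.is_contiguous(), y.device.type


if __name__ == "__main__":
    print(demo())
```

As noted above, doing this did not fix the black output here, so treat it as something to rule out rather than a fix.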