
thanks (it's 10x faster than JAX)!

See original GitHub issue

I’ve been trying to get dalle-playground running performantly on M1, but there’s a lot of work remaining to make the JAX model work via IREE/Vulkan.

so, I tried out your pytorch model with a recent nightly of pytorch:

pip install --pre "torch>1.13.0.dev20220610" "torchvision>0.14.0.dev20220609" --extra-index-url https://download.pytorch.org/whl/nightly/cpu

…and it’s 10x faster at dalle-mega than dalle-playground was on JAX/XLA!

using dalle-mega full:

wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1:latest

generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)!
GPU looks less-than-half utilized. haven’t checked whether pytorch is the process that’s using the GPU.

these measurements are from M1 Max.
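The numbers above work out to a clean 10x speedup (27 min on JAX/XLA vs 2.7 min on PyTorch). As a minimal sketch of how such a comparison can be timed — `generate_image` here is a hypothetical stand-in, not the real model call:

```python
import time

# Hypothetical stand-in for a text-to-image call; a real run would invoke
# the min-dalle PyTorch model (or dalle-playground's JAX backend) here.
def generate_image(prompt: str) -> None:
    time.sleep(0.01)  # placeholder for model inference

def time_generation(prompt: str, backend_name: str) -> float:
    """Time one image generation and report wall-clock seconds."""
    start = time.perf_counter()
    generate_image(prompt)
    elapsed = time.perf_counter() - start
    print(f"{backend_name}: {elapsed:.2f}s for one image")
    return elapsed

time_generation("crystal maiden and lina enjoying a pint", "pytorch-cpu")

# The reported figures: 27 min on dalle-playground (JAX/XLA) vs 2.7 min
# with this pytorch model, both on M1 Max CPU.
jax_minutes, torch_minutes = 27.0, 2.7
print(f"speedup: {jax_minutes / torch_minutes:.0f}x")
```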

bonus
“crystal maiden and lina enjoying a pint together at a tavern” (generated image)

Issue Analytics

  • State:open
  • Created a year ago
  • Reactions:10
  • Comments:14 (2 by maintainers)

Top GitHub Comments

3 reactions
woctezuma commented, Aug 5, 2022

Even faster these days: you get a 4x4 grid instead of a 3x3 grid on Replicate, after the same duration.

However, this is based on Dall-E MEGA instead of Dall-E Mini, so results might differ. Not sure if better or worse.

0 reactions
Birch-san commented, Jul 2, 2022

I also tried using .contiguous() on any tensor that would be transferred to the MPS device:
https://github.com/Birch-san/min-dalle/commit/b1cf6c284a949d23f8c0cd6802bb207b876bf2af

still black.
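For context, a minimal sketch of the pattern that commit tries — forcing a contiguous copy before transferring a tensor to MPS, with a CPU fallback when MPS isn't available (the tensors here are purely illustrative):

```python
import torch

# Pick MPS when available (PyTorch 1.12+ nightly on Apple Silicon), else CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# A transposed view is a classic example of a non-contiguous tensor.
x = torch.arange(6, dtype=torch.float32).reshape(2, 3).t()
print(x.is_contiguous())  # False: .t() returns a strided view

# The workaround being tried: make a contiguous copy before the transfer,
# so the MPS backend receives plainly-laid-out memory.
y = x.contiguous().to(device)
print(y.is_contiguous())  # True
```

As the comment notes, this didn't resolve the black output, so the layout workaround alone wasn't enough.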

Read more comments on GitHub >

Top Results From Across the Web

JAX Frequently Asked Questions (FAQ)
For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs)....
Read more >
Performance comparison - Flux.jl's Adam vs Jax's Adam
Hello guys, I have been wondering about Julia's performance and would like to know if it is really true that Julia is really...
Read more >
Does JAX run slower than NumPy? - python - Stack Overflow
For individual matrix operations on CPU, JAX is often slower than NumPy, but JIT-compiled sequences of operations in JAX are often faster ......
Read more >
Enabling Fast Differentially Private SGD via Just-in-Time ...
of this algorithm is 10-100x slower than standard training,” where their implementation is based on. TensorFlow Privacy.
Read more >
