Backward pass of scan is very slow to compile on CPU
The following script takes more than 1 minute to compile on my system (with the newest versions of jax & jaxlib). A very strange phenomenon: the compile time drops to about 2s after seemingly random changes to the body function (e.g. replacing `0.5 * level` with `level`).
import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config; config.update("jax_platform_name", "cpu")

def f(s):
    def scan_fn(carry, y_t):
        level, s = carry
        exp_val = (level + 2 * level ** 0.5) * s[0]
        exp_val = np.clip(exp_val, a_min=0)
        level = y_t / s[0] + 0.5 * level
        s = y_t / level + 0.5 * s
        return (level, s), exp_val
    return lax.scan(scan_fn, (1., s), np.ones(80))[1].sum()

jit(grad(f))(np.ones(2))
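When debugging reports like this, it helps to separate tracing time from XLA compile time. A hedged sketch below uses the ahead-of-time lowering API available in recent jax releases (it did not exist in the jax version from the original report, and `g` is a stand-in function, not the reporter's):

```python
# Sketch: isolating XLA compile time from tracing via jit(...).lower(...).compile().
import time

import jax
import jax.numpy as jnp

def g(x):
    # Stand-in for a function whose compilation we want to time.
    return jnp.sum(x ** 2)

# Trace and lower only; no XLA compilation yet.
lowered = jax.jit(jax.grad(g)).lower(jnp.ones(2))

start = time.perf_counter()
compiled = lowered.compile()  # XLA compilation happens here
elapsed = time.perf_counter() - start
print(f"compile time: {elapsed:.3f}s")
```

Timing only the `.compile()` call makes it easy to confirm that a small source change (like the `0.5 * level` vs. `level` swap above) really is changing backend compile time rather than trace time.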
Issue Analytics
- State:
- Created: 4 years ago
- Comments: 10 (8 by maintainers)
This looks like it’s either an LLVM bug or an XLA bug. The computation isn’t all that large, but lots of time is being spent in the LLVM SelectionDAG logic. I filed an (internal) bug for the XLA team.
We can probably hope for XLA/LLVM to get faster, but also we can improve the kind of graph we’re staging out.
I noticed (because @hawkinsp pointed it out in related circumstances) that we're adding a lot of scalar literals to the loop-carry tuple, when instead those could be staged into the XLA computation as literals. @dougalm and I recently added some logic to instantiate scalar constants as XLA literals rather than hoisting them into the loop carry (in d27bc0a, part of #704), but we conservatively only switched that on for the Python types `int` and `float`. In particular, `DeviceArray` constants still got hoisted.

I noticed that this computation was doing a good amount of scalar hoisting, and so in #780 (specifically 9c931dd) I sketched out some logic that allows more types to be treated as literals in jaxprs (and hence in the staged-out XLA computations). That seems to make the compile time for the OP's code essentially instantaneous.
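The hoisting behavior described above can be seen directly in a jaxpr: a closed-over array constant surfaces as a jaxpr constant rather than an inline literal, the way a plain Python `0.5` would. A minimal illustrative sketch (the function and names here are hypothetical, not from the patch):

```python
# Sketch: inspecting how a closed-over array constant appears in the jaxpr
# of a lax.scan, compared with an inline Python-scalar literal.
import jax
import jax.numpy as jnp
from jax import lax

c = jnp.asarray(0.5)  # an array (DeviceArray-style) constant closed over by the body

def body(carry, x):
    # Uses the closed-over array constant `c`.
    return carry * c + x, carry

jaxpr = jax.make_jaxpr(lambda xs: lax.scan(body, 0.0, xs))(jnp.ones(4))
# Printing the jaxpr shows how `c` is staged; replacing `c` with the
# Python float 0.5 would instead inline it as a literal in the equations.
print(jaxpr)
```

Comparing the two printed jaxprs (array constant vs. Python scalar) is a quick way to check which constants a given jax version treats as literals.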
I want to look over that code with fresh eyes tomorrow, but I’m optimistic it (or something like it) will handle this scan compilation time issue and a lot of related ones.