backward pass of scan is very slow to compile in CPU

The following script took more than 1 minute to compile on my system (with the newest versions of jax and jaxlib). Strangely, the compile time drops to 2 s after seemingly arbitrary changes to the body function (e.g. replacing 0.5 * level with level).

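import jax.numpy as np
from jax import jit, grad, lax
from jax.config import config
config.update("jax_platform_name", "cpu")  # force the CPU backend

def f(s):
    # One scan step: emit a clipped value, then update the (level, s) carry.
    def scan_fn(carry, y_t):
        level, s = carry
        exp_val = (level + 2 * level ** 0.5) * s[0]
        exp_val = np.clip(exp_val, a_min=0)
        level = y_t / s[0] + 0.5 * level
        s = y_t / level + 0.5 * s
        return (level, s), exp_val

    # Run 80 steps over a sequence of ones and sum the per-step outputs.
    return lax.scan(scan_fn, (1., s), np.ones(80))[1].sum()

# Compiling the backward pass of the scan is the slow part.
jit(grad(f))(np.ones(2))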
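
To put a number on the compile time, one option is to time the first call of the jitted gradient (trace, compile and run) against a second call (cached executable only). The sketch below is not from the original report: it assumes a recent JAX install, the helper `time_call` is a hypothetical name, `jnp.maximum` stands in for the `clip` call, and the initial level is written as `jnp.float32(1.0)` so the carry has a fixed dtype.

```python
import time

import jax.numpy as jnp
from jax import grad, jit, lax


def f(s):
    # Same recurrence as the script above; jnp.maximum(x, 0) replaces clip.
    def scan_fn(carry, y_t):
        level, s = carry
        exp_val = (level + 2 * level ** 0.5) * s[0]
        exp_val = jnp.maximum(exp_val, 0.0)
        level = y_t / s[0] + 0.5 * level
        s = y_t / level + 0.5 * s
        return (level, s), exp_val

    init = (jnp.float32(1.0), s)  # assumption: explicit float32 carry
    return lax.scan(scan_fn, init, jnp.ones(80))[1].sum()


def time_call(fn, *args):
    """Hypothetical helper: return (result, seconds) for one blocking call."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # JAX dispatches asynchronously, so wait here
    return out, time.perf_counter() - start


grad_f = jit(grad(f))
x = jnp.ones(2)
g_first, t_first = time_call(grad_f, x)    # trace + compile + run
g_second, t_second = time_call(grad_f, x)  # cached executable only
print(f"first call:  {t_first:.3f} s")
print(f"second call: {t_second:.3f} s")
```

The gap between the two timings is roughly the tracing and compilation cost. Rerunning after a small edit to `scan_fn` (such as the `0.5 * level` change mentioned above) gives a like-for-like comparison.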

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 10 (8 by maintainers)

Top GitHub Comments

3 reactions
hawkinsp commented, May 28, 2019

This looks like it’s either an LLVM bug or an XLA bug. The computation isn’t all that large, but lots of time is being spent in the LLVM SelectionDAG logic. I filed an (internal) bug for the XLA team.

1 reaction
mattjj commented, May 29, 2019

We can probably hope for XLA/LLVM to get faster, but also we can improve the kind of graph we’re staging out.

I noticed (because @hawkinsp pointed it out in related circumstances) that we’re adding a lot of scalar literals to the loop-carry tuple, when instead those could be staged into the XLA computation as literals. @dougalm and I recently added some logic to instantiate scalar literals as XLA literals rather than hoisting them into the loop carry (in d27bc0a, part of #704), but we conservatively only switched that on for the Python types int and float. In particular, DeviceArray constants still got hoisted.

I noticed that this computation was doing a good amount of scalar-hoisting, and so in #780 (specifically 9c931dd) I sketched out some logic that allows more types to be treated as literals in jaxprs (and hence in the staged-out XLA computations). That seems to make the compile time for the OP’s code essentially instantaneous.

I want to look over that code with fresh eyes tomorrow, but I’m optimistic it (or something like it) will handle this scan compilation time issue and a lot of related ones.
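
Editorial sketch (not part of the thread): one way to look at what gets staged out for this computation is to print the jaxpr of the gradient with `jax.make_jaxpr` and check how many values each `scan` equation carries. This assumes a recent JAX install. The parameter names `num_consts` and `num_carry` are internal details of the `scan` primitive and may differ between versions, so they are read with `.get` and print as `None` if absent. As in any current JAX, the output will not reproduce the 2019 hoisting behaviour described above.

```python
import jax
import jax.numpy as jnp
from jax import grad, lax


def f(s):
    # Same recurrence as the original script; jnp.maximum(x, 0) replaces clip.
    def scan_fn(carry, y_t):
        level, s = carry
        exp_val = (level + 2 * level ** 0.5) * s[0]
        exp_val = jnp.maximum(exp_val, 0.0)
        level = y_t / s[0] + 0.5 * level
        s = y_t / level + 0.5 * s
        return (level, s), exp_val

    init = (jnp.float32(1.0), s)  # assumption: explicit float32 carry
    return lax.scan(scan_fn, init, jnp.ones(80))[1].sum()


# Trace the backward pass without compiling it.
closed = jax.make_jaxpr(grad(f))(jnp.ones(2))
jaxpr_text = str(closed)

# Collect top-level scan equations and report how much state each one threads.
scan_eqns = [e for e in closed.jaxpr.eqns if e.primitive.name == "scan"]
for eqn in scan_eqns:
    print(
        "scan: num_consts =", eqn.params.get("num_consts"),
        "num_carry =", eqn.params.get("num_carry"),
    )

print(jaxpr_text)
```

In the printed jaxpr, Python scalars such as `0.5` and `2` show up inline in the equations, which is the "literal" treatment the comment above refers to.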
