
Feature request: optimize permutable nested prange


Feature request

I have a function with three nested range loops. I optimize it using @njit(parallel=True).

The sizes of the ranges vary from call to call, and some can be trivial (range(1)).

In my testing, I found that the only way to mark all three as parallel is to use prange on each.

But I now observe that permuting the three loops changes the performance (my loop sizes are very unequal: one range is ~2000 and another ~20), which is understandable, as I believe only the outermost prange is actually parallelized.

What do you recommend? The best solution would be a way for Numba to understand that the loops can be permuted depending on their sizes / for cache-hit optimization (think of a generalized prange). Numba could maybe even “test” the orders at run time and select the best one (not sure whether that makes sense).

Nested loops have this structure:

for i1 in prange(n1):
    # code before
    for i2 in prange(n2):
        # recursive structure
    # code after

If there is no code before or after the inner loop, or if that code is fast, supporting a permutable version of itertools.product (a pproduct, say) would be sufficient.

for i1, i2, i3 in pproduct(range(n1), range(n2), range(n3)):
    # code before 1
    # code before 2
    # code before 3
    # code inside
    # code after 3
    # code after 2
    # code after 1
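
For what it's worth, nothing like pproduct exists today, but its effect can be approximated by collapsing the nest into a single prange over the product of the sizes and recovering the indices arithmetically. A minimal sketch, assuming there is no per-index code between the loop headers (the function name and the body expression are illustrative, the latter borrowed from the example later in this thread):

from numba import njit, prange

@njit(parallel=True)
def collapsed(n1, n2, n3):
    s = 0
    # one flat parallel loop over all n1*n2*n3 iterations: the work is
    # split across threads evenly regardless of how unequal the
    # individual ranges are
    for idx in prange(n1 * n2 * n3):
        i1 = idx // (n2 * n3)
        i2 = (idx // n3) % n2
        i3 = idx % n3
        s += (i1 * i3 < i2 * i2) - (i1 * i3 > i2 * i2)
    return s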

Otherwise, one would need to ensure that the “code before” inside the prange(n2) loop doesn’t use i1, either by computing a dependency graph or by providing syntax for the user to assert the independence. For example:

l1, l2, l3 = permutable(range(n1), range(n2), range(n3))
for i1 in l1:
    # code before
    for i2 in l2:
        # recursive structure
    # code after

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Reactions: 1
  • Comments: 8 (5 by maintainers)

Top GitHub Comments

stuartarchibald commented, May 14, 2018 (2 reactions)

Thanks for the request. Extracting the code from https://stackoverflow.com/questions/50255126/numba-doesnt-parallelize-range below.

from numba import njit, prange
from time import time


@njit
def f1(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


@njit
def f2(n):
    s = 0
    for i in prange(n):
        for j in prange(n):
            for k in prange(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


@njit(parallel=True)
def f3(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


@njit(parallel=True)
def f4(n):
    s = 0
    for i in prange(n):
        for j in prange(n):
            for k in prange(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s


for f in [f1, f2, f3, f4]:
    d = time()
    f(2500)
    print('%.02f' % (time() - d))

As Numba uses JIT compilation, the timing code above includes compilation time as well as execution time in the reported figures. For reference, the above gives me:

21.39
21.18
21.20
14.00

Editing the timing part of the code so that it measures execution only:

for f in [f1, f2, f3, f4]:
    f(10) # invoke once with the type used in the timed section to trigger compilation
    d = time()
    f(2500)
    print('%.02f' % (time() - d))

gives:

21.09
20.92
20.92
12.71

It doesn’t make much difference here, as the compute part of the code is quite heavy for the given n, but it should always be taken into account.

As to the four variants of the loop and what is observed in the timings:

@njit
def f1(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s

Function f1 is a standard CPU compiled triple nested loop.

@njit
def f2(n):
    s = 0
    for i in prange(n):
        for j in prange(n):
            for k in prange(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s

Function f2 has loops declared with prange but no parallel=True option set in njit; as a result, the compiler treats prange as an alias of range.

@njit(parallel=True)
def f3(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s

Function f3 has parallel=True, so the analysis for transforming the code to execute in parallel takes place; however, it correctly decides that there is nothing to parallelize (no loop was declared parallel with prange). The example in the parallel documentation that contains a range loop parallelizes the loop body, which contains many computations on arrays; it is these that are fused and transformed into a parallel region. In that example it is also worth noting that w is a loop-carried dependency (iteration i+1 needs the result from iteration i), so an embarrassingly parallel execution of the loop itself is not possible.
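
For reference, the documentation example being referred to looks roughly like this (a sketch reproduced from the docs, not verbatim):

import numpy as np
from numba import njit

@njit(parallel=True)
def logistic_regression(Y, X, w, iterations):
    # the range loop itself stays sequential: w is a loop-carried
    # dependency (iteration i+1 needs the w produced by iteration i)
    for i in range(iterations):
        # the array expressions in the body are what get fused and
        # executed as a parallel region under parallel=True
        w -= np.dot(((1.0 / (1.0 + np.exp(-Y * np.dot(X, w))) - 1.0) * Y), X)
    return w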

@njit(parallel=True)
def f4(n):
    s = 0
    for i in prange(n):
        for j in prange(n):
            for k in prange(n):
                s += (i * k < j * j) - (i * k > j * j)
    return s

Function f4 has parallel=True and all the loops declared with prange. This allows the parallel-transform analysis to run, and as there are explicitly declared parallel loops suitable for transformation, the transform is applied and the code runs more quickly.

Declaring prange on inner loops when there is an outer prange loop translates to the inner ones being run as range loops; this prevents nested parallelism and also makes larger work blocks available per thread.
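
In other words, f4 effectively executes as if it had been written with only the outermost loop parallel (a sketch of the effective behaviour, not literal compiler output):

from numba import njit, prange

@njit(parallel=True)
def f4_effective(n):
    s = 0
    for i in prange(n):         # only the outermost loop runs in parallel
        for j in range(n):      # the inner pranges are demoted to
            for k in range(n):  # ordinary serial range loops
                s += (i * k < j * j) - (i * k > j * j)
    return s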

The information about what parallel=True is doing can be obtained by setting the environment variable NUMBA_DEBUG_ARRAY_OPT_STATS; with this set, the terminal shows:

Function f3 has no Parfor.
Parallel for-loop #0 is produced from pattern '('prange', 'user')' at issue2960.py (41)
Parallel for-loop #1 is produced from pattern '('prange', 'user')' at issue2960.py (40)
Parallel for-loop #2 is produced from pattern '('prange', 'user')' at issue2960.py (39)
After fusion, parallel for-loop After fusion has 1 nested Parfor(s) #1.
After fusion, parallel for-loop After fusion has 2 nested Parfor(s) #1.
After fusion, function f4 has 1 parallel for-loop(s) #{2}.

This shows that f3 undergoes no parfor transform, while f4 has 3 parallel loops identified from prange, which are fused into a single loop (loop #2, the outermost).
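
If you would rather enable this from Python than from the shell, the variable has to be set before Numba is first imported, as Numba reads its configuration at import time (a sketch):

import os
os.environ['NUMBA_DEBUG_ARRAY_OPT_STATS'] = '1'  # set before importing numba

from numba import njit, prange  # configuration is read at import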

In answer to your feature request: at present, Numba compiles code based on the types of the arguments rather than their values; it also compiles everything to machine code up front and dispatches to compiled code based purely on type. The behaviour described in the feature request requires analysis based on run-time values, and so is more amenable to a tracing JIT, which could feasibly analyse a loop-nest instance at run time and perform dynamic loop-nest optimisations. This is out of scope for what Numba can do at present, though I would think https://github.com/numba/numba/issues/2949 will help towards being able to achieve it.
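
Until such run-time analysis exists, one workaround is to do the value-based choice in the interpreter: compile both loop orderings and dispatch on the actual sizes at call time. A hypothetical sketch with invented names, using a simplified two-loop body:

from numba import njit, prange

@njit(parallel=True)
def outer_is_n1(n1, n2):
    s = 0
    for i in prange(n1):        # parallelise over the n1 axis
        for j in range(n2):
            s += (i < j) - (i > j)
    return s

@njit(parallel=True)
def outer_is_n2(n1, n2):
    s = 0
    for j in prange(n2):        # parallelise over the n2 axis
        for i in range(n1):
            s += (i < j) - (i > j)
    return s

def f(n1, n2):
    # plain-Python dispatch: put the larger range outermost so the
    # parallel loop has enough iterations to keep every thread busy
    return outer_is_n1(n1, n2) if n1 >= n2 else outer_is_n2(n1, n2)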

This all said, I think you will get loop-nest optimisation from the LLVM backend via e.g. loop switching if you use loop bounds that are fixed at compile time, e.g. if you declare your loops with a fixed size like range(20).
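
One way to get fixed bounds is to bake the sizes in as module-level globals, which Numba freezes into the compiled code as constants (a sketch; the sizes are the illustrative ones from the question):

from numba import njit, prange

N1, N2, N3 = 2000, 20, 20  # module-level globals become compile-time constants

@njit(parallel=True)
def fixed_nest():
    s = 0
    for i in prange(N1):
        for j in range(N2):      # trip counts are known to LLVM, enabling
            for k in range(N3):  # loop-nest optimisation of the inner loops
                s += (i * k < j * j) - (i * k > j * j)
    return s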

stuartarchibald commented, Dec 15, 2021 (1 reaction)

Closing this question as it seems to be resolved. Numba now has a Discourse forum, https://numba.discourse.group/, which is great for questions like this; please do consider posting there in future 😃 Thanks!
