Function accepting njitted functions as arguments is slow

I was trying Numba 0.38 and its new support for passing jitted functions as arguments, using this code snippet:

# coding: utf-8
from scipy.optimize import newton
from numba import njit

@njit
def func(x):
    return x**3 - 1

@njit
def fprime(x):
    return 3 * x**2

@njit
def njit_newton(func, x0, fprime):
    # Newton iteration modeled on the SciPy loop linked below
    for _ in range(50):
        fder = fprime(x0)
        fval = func(x0)
        newton_step = fval / fder
        x = x0 - newton_step
        if abs(x - x0) < 1.48e-8:
            return x
        x0 = x

# 1. SciPy newton with the pure-Python versions of func/fprime
get_ipython().run_line_magic('timeit', 'newton(func.py_func, 1.5, fprime=fprime.py_func)')
# 2. SciPy newton with the jitted func/fprime
get_ipython().run_line_magic('timeit', 'newton(func, 1.5, fprime=fprime)')
# 3. Pure-Python njit_newton calling the jitted func/fprime
get_ipython().run_line_magic('timeit', 'njit_newton.py_func(func, 1.5, fprime=fprime)')
# 4. Jitted njit_newton receiving the jitted func/fprime as arguments
get_ipython().run_line_magic('timeit', 'njit_newton(func, 1.5, fprime=fprime)')

I found it surprising that njit_newton is the slowest of all, while njit_newton.py_func is the fastest:

$ ipython test_perf.py 
4.76 µs ± 8.52 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
4.14 µs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.58 µs ± 26 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
20 µs ± 85.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
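One way to test where the extra time goes, assuming it is paid when crossing from Python into Numba rather than inside the solver itself, is to cross the boundary once and run many solves inside compiled code. The `solve_many` helper below is hypothetical (not part of Numba or SciPy). `func`, `fprime` and `njit_newton` are repeated so the sketch is self-contained, with one addition: a final `return x0` so the solver always returns a float. The `njit` fallback is also an addition, so the sketch still runs (slowly) without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback if Numba is missing
    def njit(f):
        return f

@njit
def func(x):
    return x**3 - 1

@njit
def fprime(x):
    return 3 * x**2

@njit
def njit_newton(func, x0, fprime):
    for _ in range(50):
        x = x0 - func(x0) / fprime(x0)
        if abs(x - x0) < 1.48e-8:
            return x
        x0 = x
    return x0  # addition: return the last iterate instead of None

@njit
def solve_many(func, fprime, starts):
    # Hypothetical helper: a single Python -> Numba call, then many solver
    # calls that happen entirely inside compiled code.
    out = np.empty_like(starts)
    for i in range(starts.shape[0]):
        out[i] = njit_newton(func, starts[i], fprime)
    return out

roots = solve_many(func, fprime, np.linspace(0.5, 2.0, 1000))
```

Timing `solve_many` with `%timeit` and dividing by the number of starting points would give a per-solve cost with the boundary crossing amortized; no figures are claimed here.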

cc @nikita-astronaut

(Inspiration: https://github.com/scipy/scipy/blob/607a21e07dad234f8e63fcf03b7994137a3ccd5b/scipy/optimize/zeros.py#L164-L182)
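A workaround often suggested for this pattern (an assumption here, not something this issue confirms) is to avoid passing the functions as runtime arguments at all. A factory builds a specialized solver whose `func` and `fprime` are closure variables. Numba freezes closure variables at compile time, like globals, so the Python-side call only has to dispatch on `x0`. `make_newton` is a hypothetical name, and the `njit` fallback is an addition so the sketch runs without Numba.

```python
try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback if Numba is missing
    def njit(f):
        return f

@njit
def func(x):
    return x**3 - 1

@njit
def fprime(x):
    return 3 * x**2

def make_newton(func, fprime, tol=1.48e-8, maxiter=50):
    # Hypothetical factory: func, fprime, tol and maxiter are captured from
    # the enclosing scope and frozen when `newton` is compiled, instead of
    # being passed (and type-dispatched) on every call from Python.
    @njit
    def newton(x0):
        for _ in range(maxiter):
            x = x0 - func(x0) / fprime(x0)
            if abs(x - x0) < tol:
                return x
            x0 = x
        return x0
    return newton

newton_cubic = make_newton(func, fprime)
root = newton_cubic(1.5)
```

The trade-off is one compilation per `make_newton` call, so this suits cases where the same `func`/`fprime` pair is solved many times.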

Issue Analytics

  • State: open
  • Created 5 years ago
  • Reactions: 1
  • Comments: 21 (15 by maintainers)

Top GitHub Comments

Acmion commented, Aug 11, 2022 (1 reaction)

I can confirm that this issue exists. However, as mentioned above, it does seem to be caused by the overhead of calling Numba-jitted code from Python.

import numpy as np
import numba as nb

@nb.njit
def foo(x):
    return x

@nb.njit
def foo_bad(x, func):
    return x

@nb.njit
def foo_bad_alt(x, func):
    return func(x)

@nb.njit
def bar(x):
    s = 0.0

    for v in x:
        s += v

    return s

@nb.njit
def bar_bad(x, func):
    s = 0.0

    for v in x:
        s += v

    return s

@nb.njit
def bar_bad_alt(x, func):
    return func(x)

%timeit foo(10)
%timeit foo_bad(10, foo)
%timeit foo_bad_alt(10, foo)

# 167 ns ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
# 17.5 µs ± 1.54 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# 15.7 µs ± 591 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

data = np.random.random(1000000)

%timeit bar(data)
%timeit bar_bad(data, bar)
%timeit bar_bad_alt(data, bar)

# 1.02 ms ± 57.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.1 ms ± 65.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.05 ms ± 29.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

The performance difference between the foo functions is large. However, because timeit calls them from Python, these timings are dominated by the cost of invoking Numba code from Python.

The difference between the bar functions is minimal, because most of the time is now spent inside the function rather than in the interface between Python and Numba.
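A back-of-the-envelope check using only the timings quoted in this comment shows why. It assumes the extra time of foo_bad over foo is a fixed per-call overhead, independent of the work done inside the function.

```python
# Timings quoted in the comment above, in microseconds.
foo_us = 0.167       # %timeit foo(10)
foo_bad_us = 17.5    # %timeit foo_bad(10, foo): same body, extra function argument
bar_us = 1020.0      # %timeit bar(data), 1,000,000 elements
bar_std_us = 57.5    # reported std. dev. of bar(data)

# Assumption: the gap between foo_bad and foo is a fixed per-call overhead.
overhead_us = foo_bad_us - foo_us

print(f"overhead: {overhead_us:.1f} µs")
print(f"vs foo: {overhead_us / foo_us:.0f}x foo's own call time")
print(f"vs bar: {overhead_us / bar_us:.1%} of bar's runtime")
print(f"smaller than bar's run-to-run spread: {overhead_us < bar_std_us}")
```

This prints an overhead of about 17.3 µs. That is roughly 104x the cost of calling foo, but only about 1.7% of bar's runtime, and it is below bar's own 57.5 µs spread.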

luk-f-a commented, Jul 31, 2019 (1 reaction)

For reference, if the functions do any real work, the differences disappear (and, strangely, reverse, which I cannot explain).

import numba

from numba import njit

@njit
def foo(x):
    a = 0.
    for i in range(10000000):
        a += i
    return a


def bar(x, f):  # f is accepted but never called
    a = 0.
    for i in range(10000000):
        a += i
    return a


bar_jit = njit(bar)

foo(1)  # warm-up call so compilation is not timed
%timeit foo(1)
#26.6 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

bar_jit(1, foo)  # warm-up call so compilation is not timed
%timeit bar_jit(1, foo)
#25.7 ms ± 2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
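The "reversal" is small compared with the reported spread. A quick check with only the quoted numbers, using a rough one-standard-deviation rule of thumb (an assumption, not a proper significance test):

```python
# Timings quoted in the comment above, in milliseconds: (mean, std. dev.).
foo_ms, foo_std_ms = 26.6, 2.0          # %timeit foo(1)
bar_jit_ms, bar_jit_std_ms = 25.7, 2.0  # %timeit bar_jit(1, foo)

diff_ms = foo_ms - bar_jit_ms

# Rule of thumb (assumption): a gap smaller than either run's standard
# deviation is not distinguishable from measurement noise.
within_noise = abs(diff_ms) < min(foo_std_ms, bar_jit_std_ms)

print(f"foo - bar_jit = {diff_ms:.1f} ms; within noise: {within_noise}")
```

A 0.9 ms gap against a ±2 ms spread suggests the reversal may simply be run-to-run variation, not a real speedup from passing `foo` as an argument.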