CUDA fastmath: use fast math trig / exp / log / fdivide functions
See original GitHub issue (cc @mnicely, who I think will have an interest in the resolution of this)
Compiling the following with nvcc and fast math flags:
#include <math.h>
__global__ void f(float* r, float x)
{
r[0] = cos(x);
}
(using `nvcc --std=c++11 --generate-code arch=compute_75,code=sm_75 --use_fast_math test.cu --ptx -o test_fast.ptx`)
results in the following PTX:
.visible .entry _Z1fPff(
.param .u64 _Z1fPff_param_0,
.param .f32 _Z1fPff_param_1
)
{
.reg .f32 %f<3>;
.reg .b64 %rd<3>;
ld.param.u64 %rd1, [_Z1fPff_param_0];
ld.param.f32 %f1, [_Z1fPff_param_1];
cvta.to.global.u64 %rd2, %rd1;
cos.approx.ftz.f32 %f2, %f1;
st.global.f32 [%rd2], %f2;
ret;
}
However, the following CUDA Python:
from numba import config, cuda, float32, void
from math import cos
config.DUMP_ASSEMBLY = True
@cuda.jit(void(float32[::1], float32), fastmath=True)
def f(r, x):
r[0] = cos(x)
produces a lot of code:
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__errcode__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__tidx__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__ctaidx__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__tidy__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__ctaidy__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__tidz__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__ctaidz__;
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf;
.const .align 4 .b8 __cudart_i2opi_f[24] = {65, 144, 67, 60, 153, 149, 98, 219, 192, 221, 52, 245, 209, 87, 39, 252, 41, 21, 68, 78, 110, 131, 249, 162};
.visible .entry _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf(
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_0,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_1,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_2,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_3,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_4,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_5,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_6,
.param .f32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_7
)
{
.local .align 4 .b8 __local_depot0[28];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<12>;
.reg .f32 %f<38>;
.reg .b32 %r<66>;
.reg .f64 %fd<3>;
.reg .b64 %rd<17>;
mov.u64 %SPL, __local_depot0;
ld.param.u64 %rd7, [_ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_4];
ld.param.f32 %f14, [_ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_7];
add.u64 %rd1, %SPL, 0;
mul.ftz.f32 %f15, %f14, 0f3F22F983;
cvt.rni.ftz.s32.f32 %r65, %f15;
cvt.rn.f32.s32 %f16, %r65;
mov.f32 %f17, 0fBFC90FDA;
fma.rn.ftz.f32 %f18, %f16, %f17, %f14;
mov.f32 %f19, 0fB3A22168;
fma.rn.ftz.f32 %f20, %f16, %f19, %f18;
mov.f32 %f21, 0fA7C234C5;
fma.rn.ftz.f32 %f35, %f16, %f21, %f20;
abs.ftz.f32 %f2, %f14;
setp.leu.ftz.f32 %p1, %f2, 0f47CE4780;
@%p1 bra BB0_11;
setp.eq.ftz.f32 %p2, %f2, 0f7F800000;
@%p2 bra BB0_10;
bra.uni BB0_2;
BB0_10:
mov.f32 %f24, 0f00000000;
mul.rn.ftz.f32 %f35, %f14, %f24;
bra.uni BB0_11;
BB0_2:
mov.b32 %r2, %f14;
shr.u32 %r3, %r2, 23;
bfe.u32 %r4, %r2, 23, 8;
shl.b32 %r33, %r2, 8;
or.b32 %r5, %r33, -2147483648;
mov.u32 %r59, 0;
mov.u64 %rd15, __cudart_i2opi_f;
mov.u32 %r58, -6;
mov.u64 %rd16, %rd1;
BB0_3:
.pragma "nounroll";
ld.const.u32 %r36, [%rd15];
// inline asm
{
mad.lo.cc.u32 %r34, %r36, %r5, %r59;
madc.hi.u32 %r59, %r36, %r5, 0;
}
// inline asm
st.local.u32 [%rd16], %r34;
add.s64 %rd16, %rd16, 4;
add.s64 %rd15, %rd15, 4;
add.s32 %r58, %r58, 1;
setp.ne.s32 %p3, %r58, 0;
@%p3 bra BB0_3;
add.s32 %r39, %r4, -128;
shr.u32 %r40, %r39, 5;
and.b32 %r10, %r2, -2147483648;
st.local.u32 [%rd1+24], %r59;
mov.u32 %r41, 6;
sub.s32 %r42, %r41, %r40;
mul.wide.s32 %rd10, %r42, 4;
add.s64 %rd6, %rd1, %rd10;
ld.local.u32 %r61, [%rd6];
ld.local.u32 %r60, [%rd6+-4];
and.b32 %r13, %r3, 31;
setp.eq.s32 %p4, %r13, 0;
@%p4 bra BB0_6;
mov.u32 %r43, 32;
sub.s32 %r44, %r43, %r13;
shr.u32 %r45, %r60, %r44;
shl.b32 %r46, %r61, %r13;
add.s32 %r61, %r45, %r46;
ld.local.u32 %r47, [%rd6+-8];
shr.u32 %r48, %r47, %r44;
shl.b32 %r49, %r60, %r13;
add.s32 %r60, %r48, %r49;
BB0_6:
shr.u32 %r50, %r60, 30;
shl.b32 %r51, %r61, 2;
add.s32 %r63, %r51, %r50;
shl.b32 %r19, %r60, 2;
shr.u32 %r52, %r63, 31;
shr.u32 %r53, %r61, 30;
add.s32 %r20, %r52, %r53;
setp.eq.s32 %p5, %r52, 0;
@%p5 bra BB0_7;
not.b32 %r54, %r63;
neg.s32 %r62, %r19;
setp.eq.s32 %p6, %r19, 0;
selp.u32 %r55, 1, 0, %p6;
add.s32 %r63, %r55, %r54;
xor.b32 %r64, %r10, -2147483648;
bra.uni BB0_9;
BB0_7:
mov.u32 %r62, %r19;
mov.u32 %r64, %r10;
BB0_9:
cvt.u64.u32 %rd11, %r63;
cvt.u64.u32 %rd12, %r62;
bfi.b64 %rd13, %rd11, %rd12, 32, 32;
cvt.rn.f64.s64 %fd1, %rd13;
mul.f64 %fd2, %fd1, 0d3BF921FB54442D19;
cvt.rn.ftz.f32.f64 %f22, %fd2;
neg.ftz.f32 %f23, %f22;
setp.eq.s32 %p7, %r64, 0;
selp.f32 %f35, %f22, %f23, %p7;
setp.eq.s32 %p8, %r10, 0;
neg.s32 %r56, %r20;
selp.b32 %r65, %r20, %r56, %p8;
BB0_11:
add.s32 %r29, %r65, 1;
and.b32 %r30, %r29, 1;
setp.eq.s32 %p9, %r30, 0;
selp.f32 %f6, %f35, 0f3F800000, %p9;
mul.rn.ftz.f32 %f7, %f35, %f35;
mov.f32 %f26, 0f00000000;
fma.rn.ftz.f32 %f8, %f7, %f6, %f26;
mov.f32 %f36, 0fB94D4153;
@%p9 bra BB0_13;
mov.f32 %f27, 0fBAB607ED;
mov.f32 %f28, 0f37CBAC00;
fma.rn.ftz.f32 %f36, %f28, %f7, %f27;
BB0_13:
selp.f32 %f29, 0f3C0885E4, 0f3D2AAABB, %p9;
fma.rn.ftz.f32 %f30, %f36, %f7, %f29;
selp.f32 %f31, 0fBE2AAAA8, 0fBEFFFFFF, %p9;
fma.rn.ftz.f32 %f32, %f30, %f7, %f31;
fma.rn.ftz.f32 %f37, %f32, %f8, %f6;
and.b32 %r57, %r29, 2;
setp.eq.s32 %p11, %r57, 0;
@%p11 bra BB0_15;
mov.f32 %f34, 0fBF800000;
fma.rn.ftz.f32 %f37, %f37, %f34, %f26;
BB0_15:
cvta.to.global.u64 %rd14, %rd7;
st.global.f32 [%rd14], %f37;
ret;
}
This is the slow, accurate cos implementation. It would be desirable for it to instead generate:
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__errcode__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__tidx__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__ctaidx__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__tidy__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__ctaidy__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__tidz__;
.visible .global .align 4 .u32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf__ctaidz__;
.common .global .align 8 .u64 _ZN08NumbaEnv8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf;
.visible .entry _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf(
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_0,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_1,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_2,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_3,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_4,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_5,
.param .u64 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_6,
.param .f32 _ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_7
)
{
.reg .f32 %f<3>;
.reg .b64 %rd<3>;
ld.param.u64 %rd1, [_ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_4];
ld.param.f32 %f1, [_ZN6cudapy8__main__5f$241E5ArrayIfLi1E1C7mutable7alignedEf_param_7];
cos.approx.ftz.f32 %f2, %f1;
cvta.to.global.u64 %rd2, %rd1;
st.global.f32 [%rd2], %f2;
ret;
}
There are several functions for which the approximate instruction can be used when fast math is turned on instead of the slow libdevice function - these are accessible through the __nv_fast_*
libdevice functions, which are:
- __nv_fast_cosf
- __nv_fast_exp10f
- __nv_fast_expf
- __nv_fast_fdividef
- __nv_fast_log10f
- __nv_fast_log2f
- __nv_fast_logf
- __nv_fast_powf
- __nv_fast_sincosf
- __nv_fast_sinf
- __nv_fast_tanf
So, when `fastmath=True` for a kernel, calls to `math.sin`, `math.cos`, etc. should lower to calls to these functions instead.
Note it is also possible to use the instructions by using intrinsics like:
%3 = call float @llvm.nvvm.cos.approx.ftz.f(float %a)
in IR, but since these appear to be undocumented it’s probably better to use the call to libdevice that wraps them instead, as it won’t result in any difference in generated code compared to using the intrinsic directly.
The following change on top of PR #6152 is a quick hack that demonstrates the generation of these instructions using the libdevice fast functions, for `math.sin` and `math.cos` only:
diff --git a/numba/cuda/mathimpl.py b/numba/cuda/mathimpl.py
index a9c6dc65a..21eb4f659 100644
--- a/numba/cuda/mathimpl.py
+++ b/numba/cuda/mathimpl.py
@@ -29,11 +29,11 @@ unarys += [('log10', 'log10f', math.log10)]
unarys += [('log1p', 'log1pf', math.log1p)]
unarys += [('acosh', 'acoshf', math.acosh)]
unarys += [('acos', 'acosf', math.acos)]
-unarys += [('cos', 'cosf', math.cos)]
+unarys += [('cos', 'fast_cosf', math.cos)]
unarys += [('cosh', 'coshf', math.cosh)]
unarys += [('asinh', 'asinhf', math.asinh)]
unarys += [('asin', 'asinf', math.asin)]
-unarys += [('sin', 'sinf', math.sin)]
+unarys += [('sin', 'fast_sinf', math.sin)]
unarys += [('sinh', 'sinhf', math.sinh)]
unarys += [('atan', 'atanf', math.atan)]
unarys += [('atanh', 'atanhf', math.atanh)]
Issue Analytics
- State:
- Created 3 years ago
- Reactions:1
- Comments:11 (11 by maintainers)
It would be nice to both be able to use the fast functions individually, and also to have the fastmath flag have a global effect. Exposing all the fast math functions under the `numba.cuda` (or maybe `numba.cuda.math`) namespace would be handy.

@testhound No problem — because it has to be handled in a different way to the math functions, it makes sense to consider it separately.