Launch CUDA kernels from cfuncs
Feature request
I would like a nopython `@cfunc` to be able to launch a `@cuda.jit` kernel.
```python
import numba
import numba.cuda
import numba.types as nt

@numba.cuda.jit(nt.void(nt.CPointer(nt.float32), nt.CPointer(nt.float32)))
def kernel_gpu(input, output):
    i = numba.cuda.grid(1)
    output[i] = input[i] + 1

# The launcher's signature matches XLA's GPU custom-call convention: a
# stream, an array of buffer pointers, and an opaque descriptor.
# some_calculation is a placeholder for deriving a launch configuration.
@numba.cfunc(nt.void(nt.voidptr, nt.CPointer(nt.voidptr), nt.CPointer(nt.char), nt.ulong))
def kernel_launcher(stream, buffers, opaque, opaque_len):
    blockspergrid, threadsperblock = some_calculation(opaque, opaque_len)
    kernel_gpu[blockspergrid, threadsperblock, stream](buffers[0], buffers[1])
```
I’ve been doing some experiments in JAX issue 1870 which is about allowing Numba CUDA kernels to be consumed by JAX’s XLA JIT. Most of the pieces are there to put this together, but launching CUDA kernels is proving problematic.
We can already integrate Numba CPU kernels; it's simply a matter of creating a `@cfunc` with the right signature and patching it into the JAX API.
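As a rough illustration, the CPU side might look like the sketch below, assuming XLA's CPU custom-call convention of an output pointer plus an array of input pointers; the fixed buffer length and the kernel body are illustrative, and the registration with JAX is omitted.

```python
import numba
import numba.types as nt
import numpy as np

# Hedged sketch: XLA's CPU custom-call convention passes an output pointer
# and an array of input pointers. The fixed length 4 is illustrative only.
@numba.cfunc(nt.void(nt.voidptr, nt.CPointer(nt.voidptr)))
def kernel_cpu(out_ptr, in_ptrs):
    inp = numba.carray(in_ptrs[0], 4, dtype=np.float32)
    out = numba.carray(out_ptr, 4, dtype=np.float32)
    for i in range(4):
        out[i] = inp[i] + 1.0
```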
The most basic thing we would need in order to launch a Numba CUDA kernel is to get its handle. I found `kernel_gpu[1,10]._func.get().handle`, which I thought might be a sufficient (if hacky) way of doing that, but it turns out that it broke sometime between Numba 0.48 and 0.52. Clearly it isn't a public API, so JAX shouldn't consume it, or whatever the 0.52 equivalent is.
I believe that means this JAX feature cannot happen without a Numba change to expose an appropriate API.
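For context, the kind of handle wrangling this entails looks roughly like the sketch below, which assumes a raw `CUfunction` handle and drives `cuLaunchKernel` directly through ctypes. The argument packing and launch configuration here are illustrative assumptions, not a working recipe.

```python
import ctypes

cuda = ctypes.CDLL("libcuda.so")  # CUDA driver API

def launch_by_handle(handle, stream, device_ptrs, blocks, threads):
    # cuLaunchKernel expects an array of pointers, each pointing at the
    # storage of one kernel argument (here: device pointers).
    args = [ctypes.c_void_p(p) for p in device_ptrs]
    params = (ctypes.c_void_p * len(args))(
        *(ctypes.addressof(a) for a in args))
    cuda.cuLaunchKernel(
        handle,
        blocks, 1, 1,    # grid dimensions
        threads, 1, 1,   # block dimensions
        0,               # dynamic shared memory (bytes)
        stream,          # CUstream
        params, None)    # kernel params, extra
```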
If a `@cfunc` were able to launch a CUDA kernel then all of this handle wrangling could be avoided. Both the CPU and GPU kernels could be written as simple `@cfunc`s; the GPU ones would simply launch a CUDA kernel. This would allow JAX to neatly integrate with Numba without having to write CUDA code to handle launching.
@seibert mentioned on the mailing list that this is a feature Numba would like to implement, and that it may also have the advantage of reducing repeated kernel-launch overhead.
@SamPruden Have you come across the "Life of a Numba Kernel" notebook (https://github.com/gmarkall/life-of-a-numba-kernel) already? Some of the details in it from "Loading the Module" onwards might be helpful for your hacking.
I think that providing a supported way to get the kernel handle will be acceptable, but with the caveat that the ABI of compiled kernels may not be stable across releases (though it hasn’t really changed much over time, to my knowledge).
Great, let us know how you get on!
Probably a question for @gmarkall as code owner/maintainer for the CUDA target.
No reason, I just forgot to type it 😃