
Generic support for JIT compilation with custom NumPy ops


It would be great to be able to jit functions that make use of custom CPU operations, i.e., operations implemented with NumPy arrays. This would be a really valuable extension point for integrating JAX with existing codebases and algorithms, and would possibly cover the last remaining use cases for autograd.

Right now you can use custom CPU operations as long as you don’t jit, but skipping jit adds a very large amount of dispatch overhead.

My understanding is this could be possible by making use of XLA’s CustomCall support.
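
For context, here is a minimal sketch (not from the issue itself) of the eager path described above: a plain-NumPy op works on JAX arrays outside of jit, but fails once the function is traced.

import numpy as np
import jax
import jax.numpy as jnp

def custom_op(x):
  # Plain NumPy: works eagerly because JAX arrays convert via np.asarray,
  # but every call pays Python dispatch plus a device-to-host copy.
  return np.sin(np.asarray(x)) ** 2

custom_op(jnp.arange(3.0))              # works eagerly
# jax.jit(custom_op)(jnp.arange(3.0))   # fails: tracers can't be converted to NumPy arrays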

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 3
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

shoyer commented on Jan 31, 2021 (5 reactions)

I think we can consider this fixed by the new (experimental) version of host_callback.call, e.g.,

import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

def myprint(x):
  # Runs on the host and receives a plain numpy.ndarray, not a tracer.
  print('inside myprint:', type(x), x)
  return x

@jax.jit
def device_fun(x):
  # result_shape describes the callback output's shape/dtype (here, same as x).
  return hcb.call(myprint, x, result_shape=x)

device_fun(jnp.arange(10))

Prints: inside myprint: <class 'numpy.ndarray'> [0 1 2 3 4 5 6 7 8 9]

Please give this a try and file a new issue (CCing @gnecula) if you encounter any problems!
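
As a usage note on top of the example above (our addition, not from the thread): when the callback’s output shape or dtype differs from its input, result_shape can be given explicitly as a jax.ShapeDtypeStruct.

import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

def host_histogram(x):
  # Host-side NumPy; the output's shape and dtype differ from the input's.
  counts, _ = np.histogram(x, bins=4)
  return counts.astype(np.int32)

@jax.jit
def hist_fun(x):
  # Describe the output explicitly, since it can't be inferred from x.
  return hcb.call(host_histogram, x,
                  result_shape=jax.ShapeDtypeStruct((4,), jnp.int32))

hist_fun(jnp.arange(8.0))  # -> [2 2 2 2]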

Thenerdstation commented on Nov 30, 2020 (2 reactions)

Bump on this.

We would like to be able to call Python code the way tf.py_func works. I understand that type/shape inference is one of the blockers here, but having some kind of “manual” typing would work for us.

So something like:

def f(a: np.ndarray) -> np.ndarray:
  res = ...  # some non-JAX math
  return res

# give_signature and Shape are hypothetical, sketching the desired API
typed_f = jax.give_signature(
    function=f,
    input=Shape((1,), jnp.float32),
    output=Shape((1,), jnp.float32))

@jax.jit
def g(a):
  return typed_f(a)  # should work, assuming our earlier code worked
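
For what it’s worth, the host_callback.call answer above appears to cover this request; here is a sketch (our addition) of the same pattern with an explicit output signature, where f stands in for the non-JAX math.

import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

def f(a):
  # Stand-in for the "some non-JAX math" in the snippet above.
  return np.tanh(a)

@jax.jit
def g(a):
  # result_shape plays the role of the requested "manual" signature.
  return hcb.call(f, a, result_shape=jax.ShapeDtypeStruct((1,), jnp.float32))

g(jnp.ones((1,), jnp.float32))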

