Support serialization/export of compiled computations

Context:

I would like to use JAX to express a set of tensor computations in NumPy-like syntax, but defer their actual execution until later. Ideally, the compiled function would be registered so that it can be looked up from a shared library (the function would need to be called from a C++ program or library).

My initial idea was to:

  • use jax.numpy for describing the computations
  • export the XLA HLO when jitting on materialized tensors with the shapes/types of interest
  • compile the HLO into executable functions and link them into a .so using the approach in tensorflow/compiler/aot

Assuming this approach makes sense (please let me know if there is a better way), how could I extract the XLA HLO during the second step?
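The following is a minimal sketch of one way to extract the HLO. It assumes a recent JAX release that has the ahead-of-time jax.jit(...).lower(...) API. The function my_computation and the shapes are purely illustrative.

jax.ShapeDtypeStruct describes each input by shape and dtype, so no concrete tensors need to be materialized. By default as_text() returns StableHLO, and dialect="hlo" asks for classic XLA HLO text instead.

```python
import jax
import jax.numpy as jnp


def my_computation(x, y):
    # Hypothetical example: any jax.numpy expression works here.
    return jnp.sum(jnp.tanh(x @ y))


# Abstract inputs: only shapes and dtypes, no real arrays are allocated.
x_spec = jax.ShapeDtypeStruct((4, 8), jnp.float32)
y_spec = jax.ShapeDtypeStruct((8, 2), jnp.float32)

# Trace and lower for exactly these shapes/dtypes (nothing is executed).
lowered = jax.jit(my_computation).lower(x_spec, y_spec)

# Classic XLA HLO text; lowered.as_text() alone would give StableHLO (MLIR).
hlo_text = lowered.as_text(dialect="hlo")
print(hlo_text)
```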

Issue Analytics

  • State: open
  • Created 5 years ago
  • Reactions: 35
  • Comments: 22 (3 by maintainers)

Top GitHub Comments

12 reactions
nrontsis commented, Feb 15, 2021

Hey, I would be interested in contributing to this.

My use case is slightly different: I simply want to be able to save jitted functions persistently. I can provide more details about why I think this is important.

As far as I understand, achieving this involves two parts:

  • The Python part, where the cache behind _xla_callable has to be replaced with a persistent one. I believe this is easy to prototype. Below is a duct-taped example that saves the simplest possible JAX function: if you run the script twice, no tracing happens in the second run. Naturally, however, serialising any non-trivial function requires the following part.
  • The XLA part, which covers serialisation and deserialisation of jaxlib.xla_extension.Executable objects. Such an Executable appears to be a PjRtStreamExecutorExecutable, which is created from a list of LocalExecutables, which in turn can be created from HloModules. HloModules can be serialised and deserialised via these methods, so it appears that we have everything we need (?).

@hawkinsp, does this sound sensible to you?

Serialisation example:
from copy import copy
from jax import tree_flatten
from jax.linear_util import WrappedFun, Store
import jaxlib.xla_extension
import dill as pickle
import jax.numpy as np
from jax.api import jit, vmap
from jax.interpreters import xla
import jax._src.util
from jax.lazy import ArrayVar


# Hack to avoid pickling error: <class 'jax._src.util.ArrayVar'>: it's not found as jax._src.util.ArrayVar
jax._src.util.ArrayVar = ArrayVar


# Hacks to Serialise PyTrees
class PythonTree:
    def __init__(self, definition):
        self.definition = definition


class PyTreeStar:
    __repr__ = lambda _: "*"


def pytree_to_serialisable(obj):
    if isinstance(obj, jaxlib.xla_extension.PyTreeDef):
        return PythonTree(obj.unflatten(obj.num_leaves * [PyTreeStar(), ]))
    else:
        return obj


def serialisable_to_pytree(obj):
    if hasattr(obj, "definition"):
        return tree_flatten(obj.definition)[1]
    else:
        return obj


STATIC_COMPILATION_IDENTIFIER = "_compiled_statically_"

original_xla_callable = copy(xla._xla_callable)


def xla_callable(fun: WrappedFun, device, backend, name, donated_invars, *arg_specs):
    # Only functions renamed by persistent_jit go through the on-disk cache.
    if not name.startswith(STATIC_COMPILATION_IDENTIFIER):
        return original_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)

    filename = name + ".dill"
    try:
        # Cache hit: reuse the pickled executable and restore the auxiliary
        # outputs (e.g. the output pytree structure) that tracing would fill in.
        with open(filename, "rb") as f:
            compiled_function, store_values = pickle.load(f)
        stores = tuple(Store() for _ in store_values)
        for store, value in zip(stores, store_values):
            store.store(serialisable_to_pytree(value))
        fun.populate_stores(stores)
    except IOError:
        # Cache miss: trace and compile as usual, then persist the result.
        compiled_function = original_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
        store_values = [pytree_to_serialisable(store.val) for store in fun.stores]
        with open(filename, "wb") as f:
            pickle.dump((compiled_function, store_values), f)

    return compiled_function


xla._xla_callable = xla_callable
xla._xla_callable.most_recent_entry = lambda: None


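def persistent_jit(function, unique_name):
    # copy() returns the very same object for plain functions, so wrap the
    # function instead of renaming the caller's function in place.
    def f(*args, **kwargs):
        return function(*args, **kwargs)
    f.__name__ = STATIC_COMPILATION_IDENTIFIER + unique_name
    jitted_function = jit(f)
    return jitted_function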


def my_function(x):
    print("tracing")
    return x  # This works
    # return np.sum(np.square(x))  # This doesn't


compiled_function = persistent_jit(vmap(my_function), unique_name="my_first_aot_compiled_function")
print(compiled_function(np.zeros(3)))
print(compiled_function(np.ones(3)))
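For the persistent-save use case in the two-part plan above, recent JAX releases ship a public jax.export module. The sketch below assumes such a release. It is an illustration of a different mechanism, not the one proposed in this comment.

It serialises the lowered StableHLO plus calling-convention metadata, not a compiled executable. The function is therefore recompiled after deserialisation, but my_function itself is not re-traced. The function body and the file name are hypothetical.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
from jax import export


def my_function(x):
    # Hypothetical non-trivial function (the kind the hack above could not save).
    return jnp.sum(jnp.square(x))


# Trace and lower once, for a fixed input shape/dtype.
exported = export.export(jax.jit(my_function))(
    jax.ShapeDtypeStruct((3,), jnp.float32))

# Serialise to bytes and store on disk.
path = os.path.join(tempfile.mkdtemp(), "my_function.jax_exported")
with open(path, "wb") as f:
    f.write(exported.serialize())

# Later, possibly in another process: load and call without re-tracing.
with open(path, "rb") as f:
    restored = export.deserialize(bytearray(f.read()))

result = restored.call(jnp.ones(3, dtype=jnp.float32))
print(result)
```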

8 reactions
hawkinsp commented, Mar 4, 2019

Thanks for your interest in JAX!

Yes, I think something like this would make a lot of sense for, say, inference use cases that want to get Python out of the way. We’ve discussed things along these lines, but haven’t done anything concrete yet.

One idea would be to add a new Python API, jax.aot_compile (probably not that exact name). Rather than running the computation immediately as JIT does, it would write a .so file and a .h file to disk that you can link into your code (or whatever headers/wrappers suit your language). I think we could definitely improve on the ergonomics of tensorflow/compiler/aot!
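The staged lower/compile API in recent JAX releases covers part of this. Compilation happens ahead of time for fixed shapes, and the resulting object can be called repeatedly without retracing.

This is a sketch that assumes such a release, and predict and the shapes are hypothetical. It still runs inside Python and does not emit the .so/.h pair described above.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical inference-style computation.
    return jnp.tanh(x @ w)


w_spec = jax.ShapeDtypeStruct((3, 2), jnp.float32)
x_spec = jax.ShapeDtypeStruct((5, 3), jnp.float32)

# Stage 1: trace and lower for these exact shapes/dtypes.
lowered = jax.jit(predict).lower(w_spec, x_spec)
# Stage 2: compile for the current backend before any call is made.
compiled = lowered.compile()

w = jnp.ones((3, 2), jnp.float32)
x = jnp.ones((5, 3), jnp.float32)
out = compiled(w, x)  # runs the precompiled executable directly
print(out.shape)
```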

If you’d like to try prototyping something along these lines, you might start from the undocumented function jax.xla_computation (https://github.com/google/jax/blob/master/jax/api.py#L155), which returns a Computation object from the XLA client. In particular, that object has a method GetSerializedProto() that returns an xla.HloModule proto containing the computation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py#L720).
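jax.xla_computation has been deprecated in more recent JAX releases. Below is a sketch of a rough modern counterpart to GetSerializedProto(). It assumes a JAX version where Lowered.compiler_ir(dialect="hlo") still returns an XLA client computation that exposes as_serialized_hlo_module_proto(), so check this against your installed version. The function f and the output file name are hypothetical.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example computation.
    return jnp.sin(x) * 2.0


lowered = jax.jit(f).lower(jax.ShapeDtypeStruct((3,), jnp.float32))

# Assumption: dialect="hlo" yields an XLA client computation object.
xla_computation = lowered.compiler_ir(dialect="hlo")
proto_bytes = xla_computation.as_serialized_hlo_module_proto()

# Write the serialized HloModuleProto so non-Python tooling can consume it.
path = os.path.join(tempfile.mkdtemp(), "f.hlo.pb")
with open(path, "wb") as fh:
    fh.write(proto_bytes)
```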

PRs welcome!

