Support serialization/export of compiled computations
Context:
I would like to use JAX to express a bunch of tensor computations in numpy-ish syntax, but delay the actual execution of the computation until later – ideally by registering the compiled function as a function that can be looked up from a shared lib (the function would need to be called from a C++ program/library).
My initial idea was to:
- use jax.numpy for describing the computations
- export the XLA HLO when jitting on materialized tensors with the shapes/types of interest
- compile the XLA into executable functions and link them into an `.so` using the approach in `tensorflow/compiler/aot`
Assuming this approach makes sense (please let me know if there is a better way), could you tell me how I could extract the XLA HLO during that second step?
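For concreteness, a minimal sketch of the first two steps might look like the following, assuming the `jax.xla_computation` API discussed later in the thread (method names vary across jaxlib versions):

```python
import jax
import jax.numpy as jnp

# Step 1: describe the computation in numpy-ish syntax.
def f(x, y):
    return jnp.tanh(jnp.dot(x, y) + 1.0)

# Step 2: trace with the concrete shapes/dtypes of interest and
# extract the XLA HLO (nothing is executed here).
x = jnp.zeros((4, 8), jnp.float32)
y = jnp.zeros((8, 2), jnp.float32)
c = jax.xla_computation(f)(x, y)

print(c.as_hlo_text())  # human-readable HLO for inspection
```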
Hey, I would be interested to contribute to this.
My use case is slightly different, as I simply want to be able to persistently save `jit`-ted functions. I can provide more details about why I think this is important. As far as I understand, there are two parts to achieving this:

1. The `python` part, where we have to replace the cache of `_xla_callable`s with a persistent one. I believe this is easy to prototype. I attach below a duct-taped example that saves the simplest `jax` function. If you run the script twice, no tracing will happen in the second run. Naturally, however, when trying to serialise any non-trivial function, the following part becomes crucial.
2. The `XLA` part, which includes serialisation and deserialisation of `jaxlib.xla_extension.Executable`s. Such an `Executable` appears to be a `PjRtStreamExecutorExecutable`, which is created from a list of `LocalExecutable`s that in turn can be created from `HloModule`s. `HloModule`s can be serialised and deserialised via these methods, so it appears that we have everything we need!(?)

@hawkinsp do these sound sensible to you?
Serialisation example:
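(The original attachment does not survive in this archive. A hypothetical stand-in that persists the traced HLO, rather than patching the `_xla_callable` cache as the comment describes, might look like the sketch below; the `xla_client.XlaComputation` constructor and the `backend.compile` call are assumptions whose exact form varies across jaxlib versions.)

```python
import os
import jax
import jax.numpy as jnp
from jax.lib import xla_bridge
from jaxlib import xla_client

CACHE_PATH = "f.hlo.pb"  # hypothetical on-disk cache location

def cached_computation(fn, *args):
    """Return an XlaComputation for fn, tracing only on a cache miss."""
    if os.path.exists(CACHE_PATH):
        # Second run: rebuild the computation from disk, no tracing.
        with open(CACHE_PATH, "rb") as fp:
            return xla_client.XlaComputation(fp.read())
    # First run: trace once, then persist the serialised HloModule proto.
    c = jax.xla_computation(fn)(*args)
    with open(CACHE_PATH, "wb") as fp:
        fp.write(c.as_serialized_hlo_module_proto())
    return c

def f(x):
    return jnp.sin(x) * 2.0

c = cached_computation(f, jnp.arange(4.0))
executable = xla_bridge.get_backend().compile(c)  # API differs across versions
```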
Thanks for your interest in JAX!
Yes, I think something like this would make a lot of sense for, say, inference use cases that want to get Python out of the way. We’ve discussed things along these lines, but haven’t done anything concrete yet.
One idea would be to add a new Python API `jax.aot_compile` (probably not that exact name), which, rather than running the computation immediately as JIT does, writes a `.so` file and a `.h` file to disk that you can link into your code (or whatever language headers/wrappers seem appropriate). I think we could definitely improve on the ergonomics of `tensorflow/compiler/aot`!

If you'd like to try prototyping something along these lines, you might start from the undocumented function `jax.xla_computation` (https://github.com/google/jax/blob/master/jax/api.py#L155), which returns a `Computation` object from the XLA client. In particular, it has a method `GetSerializedProto()` that returns an `xla.HloModule` proto containing the computation (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py#L720).

PRs welcome!
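For illustration, a short sketch of that suggestion, writing the serialized proto to disk so a `tensorflow/compiler/aot`-style pipeline could pick it up; note that on newer jaxlib versions the returned object is an `XlaComputation` and the method is `as_serialized_hlo_module_proto()` rather than `GetSerializedProto()`:

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    return jnp.maximum(jnp.dot(x, w), 0.0)

w = jnp.zeros((8, 4), jnp.float32)
x = jnp.zeros((1, 8), jnp.float32)

# Trace to an XLA computation, then dump the HloModule proto to disk
# for consumption by an ahead-of-time compilation step.
c = jax.xla_computation(predict)(w, x)
with open("predict.hlo.pb", "wb") as fp:
    fp.write(c.as_serialized_hlo_module_proto())
```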