Pickling a jitted function
Possibly a silly question, but is there a way to save a jitted function? The current compilation time for a specific function I'm using is more than a couple of hours, and I'd be very happy to save it for later use (later on the same machine, and ideally on a different machine altogether; I guess this is problematic, since compilation depends on the machine?). I've tried to pickle it, and got:

`AttributeError: Can't pickle local object '_jit.<locals>.f_jitted'`

Is this somehow possible?
Issue Analytics
- Created: 4 years ago
- Reactions: 4
- Comments: 18 (11 by maintainers)
Very cool! Thanks for the example.

We can make the compile time invariant to `n_samps` by rolling a loop.
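A minimal sketch of the idea, using `lax.fori_loop` and illustrative names (a generic `step_fn` standing in for one sweep of the sampler; this is not the original snippet):

```python
import jax
import jax.numpy as jnp
from jax import lax

def make_sampler(step_fn, n_samps):
    # step_fn: state -> state, one sweep of the (hypothetical) sampler.
    # n_samps is a plain Python int, so the output shape is fixed at trace
    # time, but the loop itself is rolled: the traced program (and hence the
    # compile time) no longer grows with n_samps.
    @jax.jit
    def run_chain(init_state):
        all_states = jnp.zeros((n_samps,) + init_state.shape, init_state.dtype)

        def body(i, carry):
            state, all_states = carry
            state = step_fn(state)                    # one sweep
            all_states = all_states.at[i].set(state)  # store this sample
            return state, all_states

        _, all_states = lax.fori_loop(0, n_samps, body, (init_state, all_states))
        return all_states

    return run_chain

# Usage: compile time is the same whether n_samps is 10 or 10_000.
chain = make_sampler(lambda s: 0.5 * s + 1.0, n_samps=1000)
states = chain(jnp.zeros(3))
```

`lax.scan` would work just as well here, and it stacks the per-iteration outputs for you, which avoids the manual indexed update.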
You could imagine alternatives like only rolling up an inner loop that performs one sweep, or keeping that unrolled and rolling an outer loop over `n_samp`, but here I kept both flattened into one rolled loop. We could avoid doing the wasteful `all_states = indexed_update(all_states, ...)` on every iteration, either by using a `lax.cond` (sketched below) or else by keeping the outer loop unrolled and un-jitted (i.e. just roll one sweep, or a fixed number of sweeps, into a loop).

I'm afraid this doesn't much help how the compilation time scales with `N`, though, since the loop over `funcs` is still unrolled in the energy calculation. To improve that we might need XLA to stop inlining all functions…

What do you think?
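For concreteness, here is one way the `lax.cond` idea could look, interpreting the wasteful update as storing states we don't actually keep (e.g. thinning the chain; again the names are made up, not the original code):

```python
import jax
import jax.numpy as jnp
from jax import lax

def make_thinned_sampler(step_fn, n_samps, thin):
    # Like the sketch above, but only every `thin`-th state is kept, and
    # lax.cond skips the indexed update on the iterations we throw away.
    n_kept = -(-n_samps // thin)  # ceil division

    @jax.jit
    def run_chain(init_state):
        kept = jnp.zeros((n_kept,) + init_state.shape, init_state.dtype)

        def body(i, carry):
            state, kept = carry
            state = step_fn(state)
            kept = lax.cond(
                i % thin == 0,
                lambda k: k.at[i // thin].set(state),  # store this sample
                lambda k: k,                           # skip the write
                kept,
            )
            return state, kept

        _, kept = lax.fori_loop(0, n_samps, body, (init_state, kept))
        return kept

    return run_chain
```

Whether this actually saves anything depends on how XLA lowers the conditional, so it is worth benchmarking.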
Edit: The ‘solution’ I suggested below does not actually do anything helpful. See more details in https://github.com/google/jax/issues/679#issuecomment-707627988
I also came up against the problem of trying to pickle a jitted function when trying to run parallel MCMC chains using JAX to calculate the model functions and derivatives. In case it helps anyone else having this problem who comes across this issue, I found the following workaround.
While the inbuilt `multiprocessing` module has problems with pickling any transformed JAX functions due to the use of nested / lambda functions, the `multiprocess` package, which uses `dill` to perform the pickling, seems to work fine with non-jitted JAX functions or functions that call non-jitted JAX functions. However, when for example using a jitted JAX function as the `func` argument of `multiprocessing.Pool.map`, we seem to get a deadlock.

To get around this I found that if you only apply the JIT transformation within the child process, things work fine. Providing you can ensure the function to be parallelised / jitted is not called in the parent process first, you can simply replace any `jit` decorators with the following 'delayed' version that only applies the JIT transform on the first call to the function.
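A minimal sketch of such a decorator, reconstructed from that description (so the details may differ from the original):

```python
import functools
import jax

def delayed_jit(func, **jit_kwargs):
    # Wrap `func` but defer jax.jit until the first call, so tracing and
    # compilation happen in whichever process actually calls the function
    # (e.g. a multiprocess worker) rather than in the parent.
    jitted = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal jitted
        if jitted is None:
            jitted = jax.jit(func, **jit_kwargs)
        return jitted(*args, **kwargs)

    return wrapper
```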
We can then use this `delayed_jit` decorator in place of `jit`, as in the following simple example.
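A toy sketch of that usage, building on the `delayed_jit` sketch above; `log_density` and `work` are made-up names, and the important point is that the decorated function is never called in the parent before the pool maps it:

```python
import jax.numpy as jnp
from multiprocess import Pool  # dill-based fork of multiprocessing

@delayed_jit
def log_density(x):
    # Stand-in for a real model function; it gets jitted inside each worker
    # the first time that worker calls it.
    return -0.5 * jnp.sum(jnp.asarray(x) ** 2)

def work(x):
    # Return a plain float so the result pickles cleanly back to the parent.
    return float(log_density(x))

if __name__ == "__main__":
    inputs = [[float(i), float(i) + 1.0] for i in range(4)]
    with Pool(4) as pool:
        print(pool.map(work, inputs))
```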