Pickling a jitted function

Possibly a silly question, but - is there a way to save a jitted function? The current compilation time for a specific function I’m using is more than a couple of hours, and I’d be very happy to save it for later use (later on the same machine, and ideally - on a different machine altogether; I guess this is problematic, since compilation depends on the machine?). I’ve tried to pickle it, and got: AttributeError: Can't pickle local object '_jit.<locals>.f_jitted'. Is this somehow possible?
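
For context on the error itself, here is a minimal standard-library sketch of why a function returned from inside another function cannot be pickled. The names make_wrapper and try_pickle are hypothetical stand-ins, not JAX internals; the only assumption is that jit returns a nested function, as the '_jit.<locals>.f_jitted' in the message suggests.

```python
import pickle


def make_wrapper(func):
    # Mimics a decorator that returns a nested function, the same shape
    # as the '_jit.<locals>.f_jitted' named in the error message.
    def wrapped(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapped


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as err:
        # The exact exception type depends on the Python version.
        return str(err)


wrapped_abs = make_wrapper(abs)
error = try_pickle(wrapped_abs)
print(wrapped_abs.__qualname__)  # contains '<locals>'
print(error)
```

pickle stores functions by qualified name rather than by value, so anything whose name contains '<locals>' cannot be looked up again on load. Even where pickling does succeed, only a reference is saved, not compiled code.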

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 4
  • Comments: 18 (11 by maintainers)

Top GitHub Comments

4 reactions
mattjj commented, May 10, 2019

Very cool! Thanks for the example.

We can make the compile time invariant to n_samps by rolling a loop:

import jax.numpy as np
import jax.random as random
from jax import jit,vmap
from jax.ops import index_update
from jax.lax import fori_loop
import numpy as onp
from functools import partial
import itertools as it

N = 30

marg_1 = lambda i,x:x[i]
marg_2 = lambda i,j,x:x[i]*x[j]

marg_1s = [jit(partial(marg_1,i)) for i in range(N)]
marg_2s = [jit(partial(marg_2,i,j)) for i,j in list(it.combinations(range(N),r=2))]
funcs = marg_1s+marg_2s

@jit
def calc_e(factors,word):
    return np.sum(factors*np.array([func(word) for func in funcs]))

factors = np.array(onp.random.randn(len(funcs)))

def sample(key, n_samps):
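    # Split the key so the initial state and the uniforms are independent.
    key_state, key_unifs = random.split(key)
    state = random.randint(key_state,minval=0,maxval=2, shape=(N,))
    unifs = random.uniform(key_unifs, shape=(n_samps*N,))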

    def run_mh(j, loop_carry):
      state, all_states = loop_carry
      all_states = index_update(all_states,j//N,state)  # a bit wasteful
      state_flipped = index_update(state,j%N,1-state[j%N])
      dE = calc_e(factors,state_flipped)-calc_e(factors,state)
      accept = ((dE < 0) | (unifs[j] < np.exp(-dE)))
      state = np.where(accept, state_flipped, state)
      return state, all_states

    all_states = np.zeros((n_samps,N))
    # fori_loop returns the final loop carry, here the (state, all_states) tuple.
    _, all_states = fori_loop(0, n_samps * N, run_mh, (state, all_states))
    return all_states

sample = jit(sample, static_argnums=(1,))

You could imagine alternatives like only rolling up an inner loop that performs one sweep, or keeping that unrolled and rolling an outer loop over n_samps, but here I kept both flattened into one rolled loop. We could avoid doing the wasteful all_states = index_update(all_states, ...) on every iteration, either by using a lax.cond or else by keeping the outer loop unrolled and un-jitted (i.e. just roll one sweep, or a fixed number of sweeps, into a loop).
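
As a rough sketch of the second alternative (roll one sweep into a loop, keep the outer loop in plain Python), something like the following might work. It is written against a recent JAX, so it uses x.at[j].set(...) rather than index_update. The energy function is a first-order-only stand-in and the names sweep and flip_one are made up for illustration; none of this is from the original comment.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

N = 8  # kept small so the sketch compiles quickly


def energy(factors, state):
    # Stand-in energy with first-order terms only (an assumption made to
    # keep the sketch short; the original also has pairwise terms).
    return jnp.sum(factors * state)


@jax.jit
def sweep(key, factors, state):
    """One Metropolis-Hastings sweep: propose flipping each site once."""
    unifs = random.uniform(key, shape=(N,))

    def flip_one(j, state):
        flipped = state.at[j].set(1 - state[j])
        dE = energy(factors, flipped) - energy(factors, state)
        accept = (dE < 0) | (unifs[j] < jnp.exp(-dE))
        return jnp.where(accept, flipped, state)

    return lax.fori_loop(0, N, flip_one, state)


def sample(key, factors, n_samps):
    key, init_key = random.split(key)
    state = random.randint(init_key, shape=(N,), minval=0, maxval=2)
    samples = []
    for _ in range(n_samps):  # outer loop stays in Python, un-jitted
        key, sweep_key = random.split(key)
        state = sweep(sweep_key, factors, state)
        samples.append(state)  # one write per sweep, not one per flip
    return jnp.stack(samples)


factors = random.normal(random.PRNGKey(0), (N,))
all_states = sample(random.PRNGKey(1), factors, n_samps=5)
print(all_states.shape)
```

Since only sweep is jitted, it is compiled once and reused for every sample, so compile time should not depend on n_samps. The trade-off is one Python-level dispatch per sweep.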

I’m afraid this doesn’t much help how the compilation time scales with N though, since the loop over funcs is still unrolled in the energy calculation. To improve that we might need XLA to stop inlining all functions…
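
One possible way around the unrolled loop over funcs, which is not what the comment above proposes, is to compute all the marginals at once with index arrays instead of one jitted function per term. This is a sketch under that assumption, again written against a recent JAX; calc_e_vectorized and pairs are names invented here, and the terms are ordered to match marg_1s + marg_2s.

```python
import itertools as it

import jax
import jax.numpy as jnp
import numpy as onp

N = 30

# Row k holds the (i, j) index pair of the k-th second-order term.
pairs = jnp.array(list(it.combinations(range(N), 2)))


@jax.jit
def calc_e_vectorized(factors, word):
    first_order = word                                     # x[i]
    second_order = word[pairs[:, 0]] * word[pairs[:, 1]]   # x[i] * x[j]
    features = jnp.concatenate([first_order, second_order])
    return jnp.sum(factors * features)


rng = onp.random.RandomState(0)
factors = jnp.array(rng.randn(N + pairs.shape[0]))
word = jnp.array(rng.randint(0, 2, size=N))
e = calc_e_vectorized(factors, word)
print(e)
```

Here the traced computation is two gathers, a multiply and a reduction, so the number of traced operations no longer grows with the number of terms, although the index arrays themselves still grow as N squared.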

What do you think?

2 reactions
matt-graham commented, Oct 13, 2020

Edit: The ‘solution’ I suggested below does not actually do anything helpful. See more details in https://github.com/google/jax/issues/679#issuecomment-707627988


I also came up against the problem of trying to pickle a jitted function when trying to run parallel MCMC chains using JAX to calculate the model functions and derivatives. In case it helps anyone else having this problem who comes across this issue, I found the following workaround.

While the built-in multiprocessing module has problems pickling any transformed JAX functions, because they use nested / lambda functions, the multiprocess package, which uses dill to perform the pickling, seems to work without problems for non-jitted JAX functions or functions that call non-jitted JAX functions. However, when using a jitted JAX function as, for example, the func argument of multiprocessing.Pool.map, we seem to get a deadlock.

To get around this, I found that if you only apply the JIT transformation within the child process, things work fine. Provided you can ensure the function to be parallelised / jitted is not called in the parent process first, you can simply replace any jit decorators with the following ‘delayed’ version, which only applies the JIT transform on the first call to the function:

from jax import api

def delayed_jit(func, *jit_args, **jit_kwargs):
    
    jitted_func = None
    
    def wrapped(*args, **kwargs):
        nonlocal jitted_func
        if jitted_func is None:
            jitted_func = api.jit(func, *jit_args, **jit_kwargs)
        return jitted_func(*args, **kwargs)
    
    return wrapped

We can then use this delayed_jit decorator in place of jit, as in the following simple example:

from multiprocess import Pool
import numpy as onp
import jax.numpy as np
from jax import api  # needed for api.grad below

@delayed_jit
def norm(x):
    return np.sum(x**2)**0.5

grad_norm = delayed_jit(api.grad(norm))

rng = onp.random.RandomState(1234)
vectors = rng.standard_normal((100, 10))
pool = Pool(4)
norm_vectors = pool.map(norm, vectors)
grad_norm_vectors = pool.map(grad_norm, vectors)