Unable to jit logpdf
Description of the bug
Hi @wesselb,
I am trying to write some GP code in JAX and accelerate it with jax.jit, but it fails because a NumPy conversion happens during tracing. One workaround is to comment out the code that checks for NaN values in the logpdf function (this works), but you may be able to suggest a better solution. Also, chex supports testing code with and without jitting, so it could be used in the tests at some point in the future.
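A possible jit-friendly alternative to removing the NaN check is to mask missing entries instead of dropping them. Dropping entries makes array shapes depend on data values, which jit cannot trace; masking keeps every shape fixed. The sketch below is a hypothetical standalone helper (`masked_gaussian_logpdf`), not Stheno's implementation. It assumes the observations form a 1-D vector, that NaN marks a missing value, and that `cov` is positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def masked_gaussian_logpdf(mean, cov, y):
    """Hypothetical helper: log-density of the non-NaN entries of y under N(mean, cov)."""
    available = ~jnp.isnan(y)
    # Fill missing entries with the mean so that their residual is exactly zero.
    resid = jnp.where(available, y, mean) - mean
    # Replace the rows/columns of missing entries by those of the identity matrix.
    # This decouples them: they add log(1) = 0 to the log-determinant and
    # nothing to the quadratic form.
    both = available[:, None] & available[None, :]
    cov = jnp.where(both, cov, jnp.eye(y.shape[0], dtype=cov.dtype))
    chol = jnp.linalg.cholesky(cov)
    alpha = solve_triangular(chol, resid, lower=True)
    # Only observed entries contribute to the normalising constant.
    n_obs = jnp.sum(available)
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (n_obs * jnp.log(2.0 * jnp.pi) + logdet + alpha @ alpha)


# Usage: compare against explicitly dropping the missing entry (eagerly).
mean = jnp.zeros(4)
cov = jnp.array([
    [2.0, 0.5, 0.1, 0.0],
    [0.5, 1.5, 0.2, 0.1],
    [0.1, 0.2, 1.0, 0.3],
    [0.0, 0.1, 0.3, 1.2],
])
y = jnp.array([0.3, jnp.nan, -0.4, 1.0])

masked = jax.jit(masked_gaussian_logpdf)(mean, cov, y)
keep = jnp.array([0, 2, 3])
dropped = multivariate_normal.logpdf(y[keep], mean[keep], cov[jnp.ix_(keep, keep)])
print(masked, dropped)  # should agree up to floating-point error
```

The design choice is to trade a little wasted computation (the masked rows still go through the Cholesky factorisation) for shapes that are known at trace time.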
Code
import jax
import jax.numpy as jnp
from stheno.jax import GP, EQ
x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = jax.jit(jax.grad(loss_fn))
grad_fn(lengthscale)
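A test that computes the same gradient with and without jax.jit would catch this kind of failure. chex provides decorators for this (see `chex.variants` in the chex documentation). To avoid a chex dependency, the sketch below hand-rolls the same idea with unittest. It replaces Stheno with a hand-written EQ-kernel Gaussian log-density meant to mirror `GP(EQ().stretch(lengthscale))(x).logpdf(y)`. The noise variance of 0.1, added for numerical stability, and the targets `jnp.sin(x)` are assumptions.

```python
import unittest

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x = jnp.arange(10.0)
y = jnp.sin(x)


def loss_fn(lengthscale):
    # EQ kernel evaluated on inputs divided by the lengthscale ("stretched").
    r = (x[:, None] - x[None, :]) / lengthscale
    # Assumed noise variance of 0.1 keeps the covariance well conditioned.
    cov = jnp.exp(-0.5 * r**2) + 0.1 * jnp.eye(x.shape[0])
    return multivariate_normal.logpdf(y, jnp.zeros_like(x), cov)


# Each variant wraps a function: identity (eager) or jax.jit.
VARIANTS = {"without_jit": lambda f: f, "with_jit": jax.jit}


class LossVariantsTest(unittest.TestCase):
    def test_grad_agrees_with_and_without_jit(self):
        lengthscale = jnp.array(1.0)
        grads = {}
        for name, transform in VARIANTS.items():
            with self.subTest(variant=name):
                grads[name] = transform(jax.grad(loss_fn))(lengthscale)
                self.assertTrue(bool(jnp.isfinite(grads[name])))
        # Both variants should produce the same gradient.
        self.assertTrue(
            bool(jnp.allclose(grads["without_jit"], grads["with_jit"], rtol=1e-3, atol=1e-4))
        )
```

Run it with `python -m unittest`. Swapping `loss_fn` for the Stheno-based loss above would reproduce the reported failure in the `with_jit` variant.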
Output
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-6-9ed89be9714a> in <module>
10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
530 device=device, backend=backend, name=flat_fun.__name__,
--> 531 donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
532 out_pytree_def = out_tree()
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1962 def bind(self, fun, *args, **params):
-> 1963 return call_bind(self, fun, *args, **params)
1964
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1978 fun_ = lu.annotate(fun_, fun.in_type)
-> 1979 outs = top_trace.process_call(primitive, fun_, tracers, params)
1980 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
688 def process_call(self, primitive, f, tracers, params):
--> 689 return primitive.impl(f, *tracers, **params)
690 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _xla_call_impl(***failed resolving arguments***)
233 compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
--> 234 keep_unused, *arg_specs)
235 try:
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
294 else:
--> 295 ans = call(fun, *args)
296 cache[key] = (ans, fun.stores)
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
324 return lower_xla_callable(fun, device, backend, name, donated_invars, False,
--> 325 keep_unused, *arg_specs).compile().unsafe_call
326
/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
312 with TraceAnnotation(name, **decorator_kwargs):
--> 313 return func(*args, **kwargs)
314 return wrapper
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
400 jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
--> 401 fun, pe.debug_info_final(fun, "jit"))
402 out_avals, kept_outputs = util.unzip2(out_type)
/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
312 with TraceAnnotation(name, **decorator_kwargs):
--> 313 return func(*args, **kwargs)
314 return wrapper
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final2(fun, debug_info)
2024 with core.new_sublevel():
-> 2025 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2026 del fun, main
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
1974 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 1975 ans = fun.call_wrapped(*in_tracers_)
1976 out_tracers = map(trace.full_raise, ans)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
167 try:
--> 168 ans = self.f(*args, **dict(self.params, **kwargs))
169 except:
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in grad_f(*args, **kwargs)
1002 def grad_f(*args, **kwargs):
-> 1003 _, g = value_and_grad_f(*args, **kwargs)
1004 return g
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
1078 if not has_aux:
-> 1079 ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
1080 else:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _vjp(fun, has_aux, reduce_axes, *primals)
2497 out_primal, out_vjp = ad.vjp(
-> 2498 flat_fun, primals_flat, reduce_axes=reduce_axes)
2499 out_tree = out_tree()
/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux, reduce_axes)
132 if not has_aux:
--> 133 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
134 else:
/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
121 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 122 jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
123 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
312 with TraceAnnotation(name, **decorator_kwargs):
--> 313 return func(*args, **kwargs)
314 return wrapper
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_nounits(fun, pvals, instantiate)
768 fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 769 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
770 assert not env
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
167 try:
--> 168 ans = self.f(*args, **dict(self.params, **kwargs))
169 except:
<ipython-input-6-9ed89be9714a> in <lambda>(lengthscale)
8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
10 grad_fn = jax.jit(jax.grad(loss_fn))
/usr/local/lib/python3.7/dist-packages/stheno/random.py in logpdf(self, x)
261 if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262 available = B.jit_to_numpy(~B.isnan(x[:, 0]))
263 if not B.all(available):
/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
/usr/local/lib/python3.7/dist-packages/lab/generic.py in jit_to_numpy(*args)
1532 else:
-> 1533 res = B.to_numpy(*args)
1534 if B.control_flow.caching:
/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
/usr/local/lib/python3.7/dist-packages/lab/generic.py in to_numpy(a)
1496 """
-> 1497 return convert(a, NPOrNum)
1498
/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
/usr/local/lib/python3.7/dist-packages/plum/promotion.py in convert(obj, type_to)
31 """
---> 32 return _convert.invoke(type_of(obj), type_to)(obj, type_to)
33
/usr/local/lib/python3.7/dist-packages/plum/function.py in wrapped_method(*args, **kw_args)
606 def wrapped_method(*args, **kw_args):
--> 607 return _convert(method(*args, **kw_args), return_type)
608
/usr/local/lib/python3.7/dist-packages/plum/promotion.py in perform_conversion(obj, _)
60 def perform_conversion(obj: type_from, _: type_to):
---> 61 return f(obj)
62
/usr/local/lib/python3.7/dist-packages/jax/core.py in __array__(self, *args, **kw)
535 def __array__(self, *args, **kw):
--> 536 raise TracerArrayConversionError(self)
537
UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TracerArrayConversionError Traceback (most recent call last)
<ipython-input-6-9ed89be9714a> in <module>
9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)
<ipython-input-6-9ed89be9714a> in <lambda>(lengthscale)
7 y = jnp.arange(10)
8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
10 grad_fn = jax.jit(jax.grad(loss_fn))
11 grad_fn(lengthscale)
/usr/local/lib/python3.7/dist-packages/stheno/random.py in logpdf(self, x)
260 # Handle missing data. We don't handle missing data for batched computation.
261 if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262 available = B.jit_to_numpy(~B.isnan(x[:, 0]))
263 if not B.all(available):
264 # Take the elements of the mean, variance, and inputs corresponding to
/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
582 # to speed up the common case.
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
586 return _convert(method(*args, **kw_args), return_type)
/usr/local/lib/python3.7/dist-packages/lab/generic.py in jit_to_numpy(*args)
1531 return B.control_flow.get_outcome("to_numpy")
1532 else:
-> 1533 res = B.to_numpy(*args)
1534 if B.control_flow.caching:
1535 B.control_flow.set_outcome("to_numpy", res)
/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
582 # to speed up the common case.
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
586 return _convert(method(*args, **kw_args), return_type)
/usr/local/lib/python3.7/dist-packages/lab/generic.py in to_numpy(a)
1495 `np.ndarray`: `a` as NumPy.
1496 """
-> 1497 return convert(a, NPOrNum)
1498
1499
/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
582 # to speed up the common case.
583 if return_type is default_obj_type:
--> 584 return method(*args, **kw_args)
585 else:
586 return _convert(method(*args, **kw_args), return_type)
/usr/local/lib/python3.7/dist-packages/plum/promotion.py in convert(obj, type_to)
30 object: `obj` converted to type `type_to`.
31 """
---> 32 return _convert.invoke(type_of(obj), type_to)(obj, type_to)
33
34
/usr/local/lib/python3.7/dist-packages/plum/function.py in wrapped_method(*args, **kw_args)
605 @wraps(self._f)
606 def wrapped_method(*args, **kw_args):
--> 607 return _convert(method(*args, **kw_args), return_type)
608
609 return wrapped_method
/usr/local/lib/python3.7/dist-packages/plum/promotion.py in perform_conversion(obj, _)
59 @_convert.dispatch
60 def perform_conversion(obj: type_from, _: type_to):
---> 61 return f(obj)
62
63
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
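The traceback boils down to `B.jit_to_numpy` requesting a NumPy conversion of a traced boolean mask (`Traced<ShapedArray(bool[10])>`). Below is a Stheno-free sketch of the same failure. The function `sum_available` is hypothetical and stands in for the NaN check in `logpdf`. It works eagerly because the mask is concrete, but under jax.jit the mask is an abstract tracer, so `np.asarray` raises.

```python
import jax
import jax.numpy as jnp
import numpy as np


def sum_available(y):
    # np.asarray needs a concrete value; under jit, y is an abstract tracer.
    mask = np.asarray(~jnp.isnan(y))
    return jnp.sum(y[mask])


y = jnp.array([1.0, jnp.nan, 3.0])
eager = sum_available(y)  # works: the mask is a concrete NumPy array
print(eager)  # 4.0

try:
    jax.jit(sum_available)(y)
    jit_failed = False
except jax.errors.TracerArrayConversionError:
    jit_failed = True
print("jit raised TracerArrayConversionError:", jit_failed)  # True
```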
Description of your environment
Tried this in Google Colab.
Top GitHub Comments
I see! Hmm, this might be challenging. Dispatch currently relies heavily on types, and the type of a PyTree is somewhat troublesome. You're right that jaxtyping offers a PyTree type, but that type seems to perform only instance checking rather than providing the recursive type definition that we would like. I'll have to think about this! I agree that it would be super useful to support PyTrees.

I think that it would be possible to convert the static number to a dynamic depth. However, perhaps the right solution here is to see if we can actually give PyTrees first-class support. I'll soon be working on a 2.0 of Plum, which is where the current restrictions come from. I will put PyTree support on the list of desired improvements!
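To illustrate why PyTrees are awkward for type-based dispatch: a PyTree's "type" depends on its nested structure and on the types of its leaves, so the outer Python class alone says little. The sketch below uses a hypothetical helper, `is_pytree_of_arrays`, built on `jax.tree_util.tree_leaves`. Like the jaxtyping type mentioned above, it performs only an instance check. It is not a Plum type or a recursive type definition. It assumes a JAX version that exposes `jax.Array` (0.4 or later).

```python
import jax
import jax.numpy as jnp


def is_pytree_of_arrays(tree) -> bool:
    """Hypothetical helper: True if every leaf of `tree` is a JAX array."""
    # The check must walk the whole structure; the outer class (e.g. dict)
    # says nothing about what the leaves are.
    leaves = jax.tree_util.tree_leaves(tree)
    return all(isinstance(leaf, jax.Array) for leaf in leaves)


params = {"lengthscale": jnp.array(1.0), "noise": [jnp.array(0.1)]}
print(type(params))  # <class 'dict'>
print(jax.tree_util.tree_structure(params))  # the nested structure
print(is_pytree_of_arrays(params))  # True
print(is_pytree_of_arrays({"lengthscale": 1.0}))  # False: the leaf is a Python float
```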