Unable to jit logpdf

Description of the bug

Hi @wesselb,

I am trying to write some GP code in JAX and accelerate it with jax.jit, but it fails because a NumPy conversion happens during tracing. Commenting out the code that checks for NaN values in the logpdf function works around the problem, but you may be able to suggest a better solution. Also, chex lets you test code both with and without jitting, which could be useful for testing at some point in the future.
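The following is a minimal sketch of the failure mode outside stheno, assuming only JAX and NumPy. Converting a NaN mask to NumPy works eagerly but raises `jax.errors.TracerArrayConversionError` under `jax.jit`. The same computation kept in `jax.numpy` traces fine. The function names are hypothetical. `check_with_and_without_jit` is a plain-JAX stand-in for the with/without-jit testing that chex offers; it is not chex's API.

```python
import numpy as np
import jax
import jax.numpy as jnp


def count_observed_numpy(y):
    # Pulls the NaN mask out to NumPy, similar in spirit to the failing line in logpdf.
    # Works eagerly, but under jax.jit `y` is a tracer with no concrete value.
    mask = np.asarray(~jnp.isnan(y))
    return int(mask.sum())


def count_observed_jax(y):
    # Keeps the mask as a JAX array, so the function can be traced and compiled.
    return jnp.sum(~jnp.isnan(y))


def check_with_and_without_jit(fn, *args):
    # Hypothetical helper: run `fn` eagerly and jitted, and require matching results.
    eager = fn(*args)
    jitted = jax.jit(fn)(*args)
    np.testing.assert_allclose(np.asarray(eager), np.asarray(jitted))
    return eager


y = jnp.array([1.0, jnp.nan, 3.0])
print(check_with_and_without_jit(count_observed_jax, y))

try:
    check_with_and_without_jit(count_observed_numpy, y)
except jax.errors.TracerArrayConversionError:
    print("count_observed_numpy cannot be jitted")
```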

Code

import jax
import jax.numpy as jnp

from stheno.jax import GP, EQ

x = jnp.arange(10)
y = jnp.arange(10)
lengthscale = jnp.array(1.0)
loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
grad_fn = jax.jit(jax.grad(loss_fn))
grad_fn(lengthscale)

Output

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-6-9ed89be9714a> in <module>
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    530         device=device, backend=backend, name=flat_fun.__name__,
--> 531         donated_invars=donated_invars, inline=inline, keep_unused=keep_unused)
    532     out_pytree_def = out_tree()

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
   1962   def bind(self, fun, *args, **params):
-> 1963     return call_bind(self, fun, *args, **params)
   1964 

/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1978   fun_ = lu.annotate(fun_, fun.in_type)
-> 1979   outs = top_trace.process_call(primitive, fun_, tracers, params)
   1980   return map(full_lower, apply_todos(env_trace_todo(), outs))

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    688   def process_call(self, primitive, f, tracers, params):
--> 689     return primitive.impl(f, *tracers, **params)
    690   process_map = process_call

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _xla_call_impl(***failed resolving arguments***)
    233   compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
--> 234                               keep_unused, *arg_specs)
    235   try:

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    294     else:
--> 295       ans = call(fun, *args)
    296       cache[key] = (ans, fun.stores)

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    324     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
--> 325                               keep_unused, *arg_specs).compile().unsafe_call
    326 

/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
    400     jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
--> 401         fun, pe.debug_info_final(fun, "jit"))
    402   out_avals, kept_outputs = util.unzip2(out_type)

/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final2(fun, debug_info)
   2024     with core.new_sublevel():
-> 2025       jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2026     del fun, main

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   1974     in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 1975     ans = fun.call_wrapped(*in_tracers_)
   1976     out_tracers = map(trace.full_raise, ans)

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in grad_f(*args, **kwargs)
   1002   def grad_f(*args, **kwargs):
-> 1003     _, g = value_and_grad_f(*args, **kwargs)
   1004     return g

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
   1078     if not has_aux:
-> 1079       ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
   1080     else:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _vjp(fun, has_aux, reduce_axes, *primals)
   2497     out_primal, out_vjp = ad.vjp(
-> 2498         flat_fun, primals_flat, reduce_axes=reduce_axes)
   2499     out_tree = out_tree()

/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux, reduce_axes)
    132   if not has_aux:
--> 133     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    134   else:

/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    121   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 122   jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
    123   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    312     with TraceAnnotation(name, **decorator_kwargs):
--> 313       return func(*args, **kwargs)
    314     return wrapper

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_nounits(fun, pvals, instantiate)
    768     fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 769     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    770     assert not env

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:

<ipython-input-6-9ed89be9714a> in <lambda>(lengthscale)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))

/usr/local/lib/python3.7/dist-packages/stheno/random.py in logpdf(self, x)
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

/usr/local/lib/python3.7/dist-packages/lab/generic.py in jit_to_numpy(*args)
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

/usr/local/lib/python3.7/dist-packages/lab/generic.py in to_numpy(a)
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:

/usr/local/lib/python3.7/dist-packages/plum/promotion.py in convert(obj, type_to)
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 

/usr/local/lib/python3.7/dist-packages/plum/function.py in wrapped_method(*args, **kw_args)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 

/usr/local/lib/python3.7/dist-packages/plum/promotion.py in perform_conversion(obj, _)
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 

/usr/local/lib/python3.7/dist-packages/jax/core.py in __array__(self, *args, **kw)
    535   def __array__(self, *args, **kw):
--> 536     raise TracerArrayConversionError(self)
    537 

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
<ipython-input-6-9ed89be9714a> in <module>
      9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
---> 11 grad_fn(lengthscale)

<ipython-input-6-9ed89be9714a> in <lambda>(lengthscale)
      7 y = jnp.arange(10)
      8 lengthscale = jnp.array(1.0)
----> 9 loss_fn = lambda lengthscale: GP(EQ().stretch(lengthscale))(x).logpdf(y)
     10 grad_fn = jax.jit(jax.grad(loss_fn))
     11 grad_fn(lengthscale)

/usr/local/lib/python3.7/dist-packages/stheno/random.py in logpdf(self, x)
    260         # Handle missing data. We don't handle missing data for batched computation.
    261         if B.rank(x) == 2 and B.shape(x, 1) == 1:
--> 262             available = B.jit_to_numpy(~B.isnan(x[:, 0]))
    263             if not B.all(available):
    264                 # Take the elements of the mean, variance, and inputs corresponding to

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

/usr/local/lib/python3.7/dist-packages/lab/generic.py in jit_to_numpy(*args)
   1531         return B.control_flow.get_outcome("to_numpy")
   1532     else:
-> 1533         res = B.to_numpy(*args)
   1534         if B.control_flow.caching:
   1535             B.control_flow.set_outcome("to_numpy", res)

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

/usr/local/lib/python3.7/dist-packages/lab/generic.py in to_numpy(a)
   1495         `np.ndarray`: `a` as NumPy.
   1496     """
-> 1497     return convert(a, NPOrNum)
   1498 
   1499 

/usr/local/lib/python3.7/dist-packages/plum/function.py in __call__(self, *args, **kw_args)
    582             # to speed up the common case.
    583             if return_type is default_obj_type:
--> 584                 return method(*args, **kw_args)
    585             else:
    586                 return _convert(method(*args, **kw_args), return_type)

/usr/local/lib/python3.7/dist-packages/plum/promotion.py in convert(obj, type_to)
     30         object: `obj` converted to type `type_to`.
     31     """
---> 32     return _convert.invoke(type_of(obj), type_to)(obj, type_to)
     33 
     34 

/usr/local/lib/python3.7/dist-packages/plum/function.py in wrapped_method(*args, **kw_args)
    605         @wraps(self._f)
    606         def wrapped_method(*args, **kw_args):
--> 607             return _convert(method(*args, **kw_args), return_type)
    608 
    609         return wrapped_method

/usr/local/lib/python3.7/dist-packages/plum/promotion.py in perform_conversion(obj, _)
     59     @_convert.dispatch
     60     def perform_conversion(obj: type_from, _: type_to):
---> 61         return f(obj)
     62 
     63 

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(bool[10])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function <lambda> at <ipython-input-6-9ed89be9714a>:9 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i64[10,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1)] b
    from line /usr/local/lib/python3.7/dist-packages/lab/jax/shaping.py:17 (_expand_dims)

  operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /usr/local/lib/python3.7/dist-packages/stheno/random.py:262 (logpdf)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
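The traceback points at stheno/random.py line 262, `available = B.jit_to_numpy(~B.isnan(x[:, 0]))`, which is followed by a Python-level `if not B.all(available):`. Under jax.jit the NaN mask is a tracer, so there is no concrete value to convert to NumPy or to branch on.

One jit-compatible alternative is to mask missing observations rather than drop them, so that all shapes stay static. The residual is zeroed at missing entries, and the matching rows and columns of the covariance are replaced with the identity. This is a standalone, hypothetical sketch, not stheno's implementation or its eventual fix. `masked_gaussian_logpdf` and `eq_kernel` are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def masked_gaussian_logpdf(mean, cov, y):
    """Log-density of the non-NaN entries of y under N(mean, cov).

    Hypothetical helper: every shape is static, so jax.jit and jax.grad can trace it.
    """
    available = ~jnp.isnan(y)  # boolean mask, stays a JAX array (no NumPy conversion)
    n_available = jnp.sum(available)

    # Fill missing entries with the mean so the residual there is exactly zero.
    residual = jnp.where(available, y, mean) - mean

    # Keep the covariance on the observed block; use the identity for missing entries.
    pair_mask = available[:, None] & available[None, :]
    cov_masked = jnp.where(pair_mask, cov, 0.0) + jnp.diag(jnp.where(available, 0.0, 1.0))

    chol = jnp.linalg.cholesky(cov_masked)
    z = solve_triangular(chol, residual, lower=True)  # z = L^{-1} residual
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (n_available * jnp.log(2.0 * jnp.pi) + log_det + jnp.dot(z, z))


def eq_kernel(x, lengthscale):
    # Exponentiated-quadratic kernel, written out by hand for this example.
    d = (x[:, None] - x[None, :]) / lengthscale
    return jnp.exp(-0.5 * d**2)


x = jnp.arange(10.0)
y = jnp.arange(10.0).at[3].set(jnp.nan)  # one missing observation


def loss_fn(lengthscale):
    cov = eq_kernel(x, lengthscale) + 1e-2 * jnp.eye(10)  # small jitter for stability
    return masked_gaussian_logpdf(jnp.zeros(10), cov, y)


grad_fn = jax.jit(jax.grad(loss_fn))
print(grad_fn(jnp.array(1.0)))
```

The trade-off is efficiency for traceability. The Cholesky factorization still runs on the full n-by-n matrix. However, the identity block adds zero to the log-determinant and nothing to the quadratic form, and the normalizing constant counts only observed entries. The result therefore matches the log-density of the observed entries alone.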

Description of your environment

Tried this in Google Colab.

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 9 (9 by maintainers)

Top GitHub Comments

wesselb commented, Sep 9, 2022

I see! Hmm, this might be challenging. Dispatch currently relies heavily on types, and the type of a PyTree is somewhat troublesome. You’re right that jaxtyping offers a PyTree type, but that type seems to perform only instance checking rather than providing the recursive type definition that we would like. I’ll have to think about this! I agree that it would be super useful to support PyTrees.
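A PyTree is any nesting of containers (lists, tuples, dicts, and registered classes) with arrays at the leaves, so its type is naturally recursive. That is one reason it is awkward to express as a single dispatch type. The sketch below contrasts a recursive type alias with a runtime check built on `jax.tree_util`, which can verify a value only after flattening it. `PyTreeOfArrays` and `is_pytree_of_arrays` are hypothetical names; neither is part of jaxtyping's or Plum's API.

```python
from typing import Dict, List, Tuple, Union

import jax
import jax.numpy as jnp

# A recursive type definition of the kind described above (hypothetical alias).
# Python's isinstance cannot check subscripted generics such as List[...] directly.
PyTreeOfArrays = Union[
    jnp.ndarray,
    List["PyTreeOfArrays"],
    Tuple["PyTreeOfArrays", ...],
    Dict[str, "PyTreeOfArrays"],
]

# Two PyTrees of arrays with different nesting.
params_flat = {"lengthscale": jnp.array(1.0), "variance": jnp.array(2.0)}
params_nested = [jnp.array(1.0), (jnp.array(2.0), {"noise": jnp.array(0.1)})]

for tree in (params_flat, params_nested):
    # Each tree has its own structure (treedef); no single container type covers both.
    print(len(jax.tree_util.tree_leaves(tree)), jax.tree_util.tree_structure(tree))


def is_pytree_of_arrays(tree):
    # Runtime instance checking: flatten the tree, then check each leaf's type.
    return all(isinstance(leaf, jnp.ndarray) for leaf in jax.tree_util.tree_leaves(tree))


print(is_pytree_of_arrays(params_nested))   # True
print(is_pytree_of_arrays({"noise": 0.1}))  # False: the leaf is a Python float
```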

wesselb commented, Oct 18, 2022

I think that it would be possible to convert the static number to a dynamic depth. However, perhaps the right solution here is to see if we can give PyTrees first-class support. I’ll soon be working on Plum 2.0, which is where the current restrictions come from. I will put PyTree support on the list of desired improvements!
