lax.dynamic_slice inside jit

See original GitHub issue

Should this work?

import jax
import jax.numpy as np
from jax import lax  # needed for lax.dynamic_slice below

@jax.jit
def sum_first_k(a, k):
  return np.sum(lax.dynamic_slice(a, (0,), (k,)))

sum_first_k(np.arange(3.0), 2)

Here’s the traceback I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-167-645715d2be42> in <module>()
----> 1 sum_first_k(np.arange(3.0), 2)

13 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
    121     _check_args(args_flat)
    122     flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123     out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    124     return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
    125 

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
    661   if top_trace is None:
    662     with new_sublevel():
--> 663       ans = primitive.impl(f, *args, **params)
    664   else:
    665     tracers = map(top_trace.full_raise, args)

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
    604 def xla_call_impl(fun, *args, **params):
    605   device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 606   compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    607   try:
    608     return compiled_fun(*args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
    206       if len(cache) > max_size:
    207         cache.popitem(last=False)
--> 208       ans = call(f, *args)
    209       cache[key] = (ans, f)
    210     return ans

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
    617   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    618   with core.new_master(pe.JaxprTrace, True) as master:
--> 619     jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    620     assert not env  # no subtraces here (though cond might eventually need them)
    621     compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)

/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    145 
    146     del gen
--> 147     ans = self.f(*args, **dict(self.params, **kwargs))
    148     del args
    149     while stack:

<ipython-input-165-9a17ef1ee2d8> in sum_first_k(a, k)
      1 @jax.jit
      2 def sum_first_k(a, k):
----> 3   return np.sum(lax.dynamic_slice(a, (0,), (k,)))

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in dynamic_slice(operand, start_indices, slice_sizes)
    607   return dynamic_slice_p.bind(
    608       operand, start_indices, slice_sizes=tuple(slice_sizes),
--> 609       operand_shape=operand.shape)
    610 
    611 def dynamic_update_slice(operand, update, start_indices):

/usr/local/lib/python3.6/dist-packages/jax/core.py in bind(self, *args, **kwargs)
    145 
    146     tracers = map(top_trace.full_raise, args)
--> 147     out_tracer = top_trace.process_primitive(self, tracers, kwargs)
    148     return full_lower(out_tracer)
    149 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
    100       tracers = map(self.instantiate_const, tracers)
    101       avals = [t.aval for t in tracers]
--> 102       out_aval = primitive.abstract_eval(*avals, **params)
    103       eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
    104       return JaxprTracer(self, PartialVal((out_aval, unit)), eqn)

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs)
   1405     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1406   elif least_specialized is ShapedArray:
-> 1407     return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
   1408   elif least_specialized is UnshapedArray:
   1409     return UnshapedArray(dtype_rule(*args, **kwargs))

/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _dynamic_slice_shape_rule(operand, start_indices, slice_sizes, operand_shape)
   2608            "start_indices, got start_inidices length {} and slice_sizes {}.")
   2609     raise TypeError(msg.format(len(start_indices), slice_sizes))
-> 2610   if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
   2611     msg = ("slice slice_sizes must be less than or equal to operand shape, "
   2612            "got slice_sizes {} for operand shape {}.")

/usr/local/lib/python3.6/dist-packages/jax/core.py in __bool__(self)
    340   def __getitem__(self, idx): return self.aval._getitem(self, idx)
    341   def __nonzero__(self): return self.aval._nonzero(self)
--> 342   def __bool__(self): return self.aval._bool(self)
    343   def __float__(self): return self.aval._float(self)
    344   def __int__(self): return self.aval._int(self)

/usr/local/lib/python3.6/dist-packages/jax/abstract_arrays.py in error(self, *args)
     36 def concretization_function_error(fun):
     37   def error(self, *args):
---> 38     raise TypeError(concretization_err_msg(fun))
     39   return error
     40 

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

I know XLA can’t have variable sized outputs, but here I’m summing the outputs, so in principle that shouldn’t be an issue.

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 1
  • Comments: 9 (6 by maintainers)

Top GitHub Comments

6 reactions
shoyer commented, Jul 10, 2019

It would be nice if the error message said something like this, rather than sending me down a rabbit hole. What actually happened is that I first tried using indexing like a[:k], which generated an error encouraging me to try lax.dynamic_slice.

On Tue, Jul 9, 2019 at 9:49 PM Matthew Johnson notifications@github.com wrote:

No, it shouldn’t work: actually it’s not just that XLA (and JAX’s jit, which is what’s actually raising the error here for tracing reasons) require fixed output shapes, but all the shapes of the intermediates need to be fixed too. So summing the output doesn’t help; that lax.dynamic_slice alone is a problem.

Here are two alternatives, both of which you probably know about:

from __future__ import print_function
from functools import partial

import jax
import jax.numpy as np

@partial(jax.jit, static_argnums=(1,))
def sum_first_k(a, k):
  return np.sum(jax.lax.dynamic_slice(a, (0,), (k,)))

print(sum_first_k(np.arange(3.0), 2))

@jax.jit
def sum_first_k(a, k):
  n = len(a)
  return np.sum(np.where(np.arange(n) < k, a, 0))

print(sum_first_k(np.arange(3.0), 2))

The first is a way of solving the problem with recompilation. The second is a way to solve it with masking, for which XLA can still generate very efficient code by fusing the selection into the reduction rather than round-tripping several arrays to memory. A third strategy is to use a loop construct.

WDYT?

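A minimal sketch of that third strategy (an editor's addition, not part of the quoted reply), using jax.lax.fori_loop, which tolerates a traced trip count under jit so k can stay an ordinary argument:

import jax
import jax.numpy as np
from jax import lax

@jax.jit
def sum_first_k(a, k):
  # Accumulate a[0] + ... + a[k-1]; the carry has a fixed shape and dtype,
  # so only the number of loop iterations depends on k.
  def body(i, acc):
    return acc + a[i]
  return lax.fori_loop(0, k, body, np.zeros((), a.dtype))

print(sum_first_k(np.arange(3.0), 2))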

4 reactions
juesato commented, May 12, 2020

I think it’s worth adding to the dynamic_slice() docstring that slice_sizes needs to be static. I can send in a PR if that sounds good, WDYT?

I ran into the same issue as @shoyer above: I wanted a dynamic slice_sizes, first tried indexing a[:k], was told to use dynamic_slice(), got this error message, poked around a bit, and ended up here.
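
For reference, a minimal sketch (an editor's addition, not from the thread; sum_window is a made-up name) of the distinction such a docstring note would capture: under jit, the start_indices passed to dynamic_slice may be traced values, but slice_sizes must be concrete Python integers.

import jax
import jax.numpy as np
from jax import lax

@jax.jit
def sum_window(a, start):
  # start is traced, which is fine; the slice size 2 must be static.
  return np.sum(lax.dynamic_slice(a, (start,), (2,)))

print(sum_window(np.arange(5.0), 1))  # sums a[1:3]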
