lax.dynamic_slice inside jit
See original GitHub issue

Should this work?
import jax
import jax.numpy as np
from jax import lax  # missing from the original paste; without it `lax` is a NameError


@jax.jit
def sum_first_k(a, k):
    """Sum the first ``k`` elements of ``a`` (minimal repro for this issue).

    Under ``jax.jit`` every argument, including the integer ``k``, is traced
    as an abstract value. ``lax.dynamic_slice`` only allows its *start
    indices* to be traced; ``slice_sizes`` must be concrete Python ints,
    because XLA needs statically known array shapes. Tracing this function
    therefore raises ``TypeError``, even though the final ``np.sum`` result
    is a scalar.

    Workarounds:
      * mark ``k`` static, e.g. ``jax.jit(f, static_argnums=1)``
        (recompiles for each distinct ``k``), or
      * keep ``k`` dynamic and mask instead of slicing:
        ``np.sum(np.where(np.arange(a.shape[0]) < k, a, 0))``.
    """
    return np.sum(lax.dynamic_slice(a, (0,), (k,)))


if __name__ == "__main__":
    # Reproduces the TypeError shown in the traceback below.
    sum_first_k(np.arange(3.0), 2)
Here’s the traceback I get:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-167-645715d2be42> in <module>()
----> 1 sum_first_k(np.arange(3.0), 2)
13 frames
/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
121 _check_args(args_flat)
122 flat_fun, out_tree = flatten_fun_leafout(f, in_tree)
--> 123 out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
124 return out if out_tree() is leaf else tree_unflatten(out_tree(), out)
125
/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, f, *args, **params)
661 if top_trace is None:
662 with new_sublevel():
--> 663 ans = primitive.impl(f, *args, **params)
664 else:
665 tracers = map(top_trace.full_raise, args)
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_call_impl(fun, *args, **params)
604 def xla_call_impl(fun, *args, **params):
605 device_values = FLAGS.jax_device_values and params.pop('device_values')
--> 606 compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
607 try:
608 return compiled_fun(*args)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in memoized_fun(f, *args)
206 if len(cache) > max_size:
207 cache.popitem(last=False)
--> 208 ans = call(f, *args)
209 cache[key] = (ans, f)
210 return ans
/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in xla_callable(fun, device_values, *abstract_args)
617 pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
618 with core.new_master(pe.JaxprTrace, True) as master:
--> 619 jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
620 assert not env # no subtraces here (though cond might eventually need them)
621 compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
/usr/local/lib/python3.6/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
145
146 del gen
--> 147 ans = self.f(*args, **dict(self.params, **kwargs))
148 del args
149 while stack:
<ipython-input-165-9a17ef1ee2d8> in sum_first_k(a, k)
1 @jax.jit
2 def sum_first_k(a, k):
----> 3 return np.sum(lax.dynamic_slice(a, (0,), (k,)))
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in dynamic_slice(operand, start_indices, slice_sizes)
607 return dynamic_slice_p.bind(
608 operand, start_indices, slice_sizes=tuple(slice_sizes),
--> 609 operand_shape=operand.shape)
610
611 def dynamic_update_slice(operand, update, start_indices):
/usr/local/lib/python3.6/dist-packages/jax/core.py in bind(self, *args, **kwargs)
145
146 tracers = map(top_trace.full_raise, args)
--> 147 out_tracer = top_trace.process_primitive(self, tracers, kwargs)
148 return full_lower(out_tracer)
149
/usr/local/lib/python3.6/dist-packages/jax/interpreters/partial_eval.py in process_primitive(self, primitive, tracers, params)
100 tracers = map(self.instantiate_const, tracers)
101 avals = [t.aval for t in tracers]
--> 102 out_aval = primitive.abstract_eval(*avals, **params)
103 eqn = JaxprEqn(tracers, None, primitive, (), False, False, params)
104 return JaxprTracer(self, PartialVal((out_aval, unit)), eqn)
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in standard_abstract_eval(shape_rule, dtype_rule, *args, **kwargs)
1405 return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
1406 elif least_specialized is ShapedArray:
-> 1407 return ShapedArray(shape_rule(*args, **kwargs), dtype_rule(*args, **kwargs))
1408 elif least_specialized is UnshapedArray:
1409 return UnshapedArray(dtype_rule(*args, **kwargs))
/usr/local/lib/python3.6/dist-packages/jax/lax/lax.py in _dynamic_slice_shape_rule(operand, start_indices, slice_sizes, operand_shape)
2608 "start_indices, got start_inidices length {} and slice_sizes {}.")
2609 raise TypeError(msg.format(len(start_indices), slice_sizes))
-> 2610 if not onp.all(onp.less_equal(slice_sizes, operand.shape)):
2611 msg = ("slice slice_sizes must be less than or equal to operand shape, "
2612 "got slice_sizes {} for operand shape {}.")
/usr/local/lib/python3.6/dist-packages/jax/core.py in __bool__(self)
340 def __getitem__(self, idx): return self.aval._getitem(self, idx)
341 def __nonzero__(self): return self.aval._nonzero(self)
--> 342 def __bool__(self): return self.aval._bool(self)
343 def __float__(self): return self.aval._float(self)
344 def __int__(self): return self.aval._int(self)
/usr/local/lib/python3.6/dist-packages/jax/abstract_arrays.py in error(self, *args)
36 def concretization_function_error(fun):
37 def error(self, *args):
---> 38 raise TypeError(concretization_err_msg(fun))
39 return error
40
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
I know XLA can’t have variable-sized outputs, but here I’m summing the outputs, so in principle that shouldn’t be an issue.
Issue Analytics
- State:
- Created 4 years ago
- Reactions:1
- Comments:9 (6 by maintainers)
Top Results From Across the Web
jax.lax.dynamic_slice - JAX documentation - Read the Docs
Inside a JIT compiled function, only static values are supported (all JAX arrays inside JIT must have ... Here is a simple two-dimensional...
Read more >
JAX Apply function only on slice of array under jit
First, the slices are producing dynamically shaped arrays (not allowed in jitted code). Second, unlike numpy arrays, JAX arrays are immutable ( ...
Read more >
Common Gotchas in JAX - Colaboratory - Google Colab
… inside `jit`'d code and `lax.while_loop` or `lax.fori_loop`, the size of slices can't be functions of argument values but only functions of argument...
Read more >
How to use the jax.lax function in jax - Snyk
To help you get started, we've selected a few jax.lax examples, based on popular ways it is used in public projects.
Read more >
python/google/jax/jax/_src/lax/slicing.py Example - Program Talk
... arrays inside JIT must have statically known size). Returns: An array containing the slice. Examples: Here is a simple two-dimensional dynamic slice: ......
Read more >
Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start Free
Top Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
It would be nice if the error message said something like this, rather than sending me down a rabbit hole. What actually happened is that I first tried using indexing like `a[:k]`, which generated an error encouraging me to try `lax.dynamic_slice`.
On Tue, Jul 9, 2019 at 9:49 PM Matthew Johnson notifications@github.com wrote:
I think it’s worth adding to the `dynamic_slice()` docstring that `slice_sizes` needs to be static. I can send in a PR if that sounds good, WDYT?

I ran into the same issue as @shoyer above, where I want a dynamic `slice_sizes`. I first tried indexing `a[:k]`, then was told to use `dynamic_slice()`, got this error message, poked around a bit, and then ended up here.