NonConcreteBooleanIndexError on call to jnp.unique
I’m getting the following error:
/usr/local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _expand_bool_indices(idx, shape)
3875 if not type(abstract_i) is ConcreteArray:
3876 # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
-> 3877 raise errors.NonConcreteBooleanIndexError(abstract_i)
3878 elif _ndim(i) == 0:
3879 raise TypeError("JAX arrays do not support boolean scalar indices")
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[4897])
as the result of a call to jnp.unique (line 204 of stuff_proj.py):
    202 """
    203
--> 204 ages_pred = jnp.unique(age_idx)
where age_idx is an int DeviceArray:
DeviceArray([24, 25, 26, ..., 6, 14, 10], dtype=int32)
It’s not immediately clear why this error would be raised in this context. Any ideas appreciated.
Thanks - don’t worry about the reproduction, I think I understand where it’s coming from now, and I can work on improving the error message. The root cause is attempting to JIT-compile jnp.unique, which returns an array with data-dependent shape. If you want to use this within JIT, you’ll need to statically specify the size argument to jnp.unique.

The code is part of a much bigger model that I can’t share, but I will try to create a smaller, reproducible example.
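
For anyone hitting the same error, here is a minimal sketch of the size-based workaround described in the maintainer's reply. The sample data, the function name unique_ages, and the constant MAX_UNIQUE are made up for illustration and are not taken from the original model. It assumes a JAX version where jnp.unique accepts the size and fill_value keyword arguments.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real age_idx array from the issue.
age_idx = jnp.array([24, 25, 26, 6, 14, 10, 24, 6], dtype=jnp.int32)

# Outside of jit, the output length may depend on the data.
ages_eager = jnp.unique(age_idx)

# Assumed upper bound on the number of distinct values; it must be a
# Python constant so the output shape is known at trace time.
MAX_UNIQUE = 8

@jax.jit
def unique_ages(idx):
    # size fixes the output length; unused trailing slots are filled
    # with fill_value (here -1, chosen as an invalid age index).
    return jnp.unique(idx, size=MAX_UNIQUE, fill_value=-1)

ages_jit = unique_ages(age_idx)
print(ages_eager)
print(ages_jit)
```

The padded result always has length MAX_UNIQUE, so downstream code needs to mask out the fill values itself. If there are more distinct values than size allows, the extras are truncated, so the bound should be chosen generously.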