Unraveling an array with dtype bool results in pytree arrays with dtype float32
Please:
- Check for duplicate issues.
- Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

x = jnp.arange(10, dtype=jnp.float32)
x_flat, unravel = ravel_pytree(x)
y = x_flat < 5.3
print(y.dtype)           # => bool
print(unravel(y).dtype)  # => float32
```
- If applicable, include full error messages/tracebacks. n/a
I’m running Python 3.9 with jax 0.2.19 and jaxlib 0.1.70.
Issue Analytics
- Created 2 years ago
- Comments:5 (4 by maintainers)
Hey @mattjj, thanks for explaining and putting together https://github.com/google/jax/pull/8951! LGTM!
Sorry @samuela, I let this one slip through the cracks somehow!
The generated `unravel` function just isn’t dtype-polymorphic; it produces an output with type (i.e. pytree structure, array shapes, and array dtypes) equal to that of the provided input. Perhaps it should’ve even raised an error when given a bool-dtype input…

One reason for that is to handle the case when there are mixed dtypes in the input pytree. Since the raveled array needs to have a uniform dtype, we promote the dtypes of the inputs, and when unraveling we cast back to the input’s dtypes. But that means the raveled array’s dtype doesn’t determine the input dtypes in general, and hence it’s not clear how to make an `unravel` which is polymorphic in the way you expected.

I’m not sure how to reconcile these two behaviors (supporting multiple distinct input dtypes, and producing a dtype-polymorphic `unravel`) in one function without making the behavior complicated. We could just special-case it to be “if all the input dtypes are equal, then produce a polymorphic `unravel`”. But that might make a simple function kind of unpredictable.

Since the function is small, would it make sense just to have your own implementation? (Or, in the intervening time, did you find other solutions?)
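A minimal sketch of the trade-off described above, assuming a recent JAX install. The first few lines call `jax.flatten_util.ravel_pytree` on a mixed int32/float32 pytree to show the promote-then-cast-back behavior. `ravel_pytree_polymorphic` is a hypothetical helper, not part of JAX and not the implementation referred to elsewhere in this thread. It covers only the “all input dtypes are equal” case and simply raises on mixed dtypes.

```python
import math

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Mixed dtypes: ravel_pytree promotes int32 and float32 leaves to a float32
# vector, and its unravel casts each chunk back to the leaf's original dtype.
mixed = {"a": jnp.arange(3, dtype=jnp.int32), "b": jnp.ones(2, dtype=jnp.float32)}
mixed_flat, mixed_unravel = ravel_pytree(mixed)
print(mixed_flat.dtype)                      # float32
print(mixed_unravel(mixed_flat)["a"].dtype)  # int32


def ravel_pytree_polymorphic(tree):
    """Hypothetical helper (not a JAX API): ravel a pytree whose leaves all
    share one dtype, returning an unravel function that keeps the dtype of
    whatever flat array it is given."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    leaves = [jnp.asarray(leaf) for leaf in leaves]
    leaf_dtypes = {leaf.dtype for leaf in leaves}
    if len(leaf_dtypes) > 1:
        # Mixed dtypes would need promotion, so refuse instead of guessing.
        raise TypeError(f"expected a single leaf dtype, got {leaf_dtypes}")
    shapes = [leaf.shape for leaf in leaves]
    sizes = [math.prod(shape) for shape in shapes]
    if leaves:
        flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    else:
        flat = jnp.zeros((0,), dtype=jnp.float32)

    def unravel(arr):
        # Slice and reshape only, with no dtype conversion, so a bool input
        # yields bool leaves.
        out, start = [], 0
        for shape, size in zip(shapes, sizes):
            out.append(arr[start:start + size].reshape(shape))
            start += size
        return jax.tree_util.tree_unflatten(treedef, out)

    return flat, unravel


# The example from the issue, using the hypothetical helper instead.
x = jnp.arange(10, dtype=jnp.float32)
x_flat, unravel = ravel_pytree_polymorphic(x)
y = x_flat < 5.3
print(unravel(y).dtype)  # bool
```

Raising on mixed dtypes keeps the helper predictable: its `unravel` never casts, so the leaf dtypes always match the flat array passed in. Mixed-dtype pytrees still need `ravel_pytree`’s promotion.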
Here’s an implementation which has the behavior you expected:
Since I’m in an issue-closing mood, and since this is an old one, I’m going to somewhat preemptively close this issue. Let us know if we should reopen and continue the discussion!