Unraveling an array with dtype bool results in pytree arrays with dtype float32


Please:

  • Check for duplicate issues.
  • Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
x = jnp.arange(10, dtype=jnp.float32)
x_flat, unravel = ravel_pytree(x)
y = x_flat < 5.3
print(y.dtype)          # => <dtype: 'bool'>
print(unravel(y).dtype) # => <dtype: 'float32'>
  • If applicable, include full error messages/tracebacks. n/a

I’m running Python 3.9 with jax 0.2.19 and jaxlib 0.1.70.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
samuela commented, Dec 14, 2021

Hey @mattjj, thanks for explaining and putting together https://github.com/google/jax/pull/8951! LGTM!

1 reaction
mattjj commented, Dec 14, 2021

Sorry @samuela, I let this one slip through the cracks somehow!

The generated unravel function just isn’t dtype-polymorphic; it produces an output with type (i.e. pytree structure, array shapes, and array dtypes) equal to that of the provided input. Perhaps it should’ve even raised an error when given a bool-dtype input…

One reason for that is to handle mixed dtypes in the input pytree. Since the raveled array needs to have a uniform dtype, we promote the dtypes of the inputs, and when unraveling we cast back to the inputs’ dtypes. But that means the raveled array’s dtype doesn’t determine the input dtypes in general, and hence it’s not clear how to make an unravel which is polymorphic in the way you expected.
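As a small illustration of the promote-then-cast-back behavior described above, the sketch below ravels a pytree with one int32 leaf and one float32 leaf using the stock jax.flatten_util.ravel_pytree. The leaf names counts and weights are made up for the example, and it assumes default JAX settings (x64 disabled).

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A pytree whose leaves have different dtypes (names are illustrative only).
tree = {"counts": jnp.arange(3, dtype=jnp.int32),
        "weights": jnp.ones(2, dtype=jnp.float32)}

# The flat array must have one dtype, so the leaves are promoted.
flat, unravel = ravel_pytree(tree)
print(flat.dtype)

# Unraveling casts each chunk back to the dtype of the matching input leaf.
restored = unravel(flat)
print(restored["counts"].dtype, restored["weights"].dtype)
```

The flat array alone carries only the promoted dtype, so the per-leaf dtypes have to be remembered by the unravel closure.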

I’m not sure how to reconcile these two behaviors (supporting multiple distinct input dtypes, and producing a dtype-polymorphic unravel) in one function, without making the behavior complicated. We could just special case it to be “if all the input dtypes are equal, then produce a polymorphic unravel”. But that might be making a simple function kind of unpredictable.
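To make the special case mentioned above concrete, here is a minimal sketch of what such a hybrid could look like. It is not the library's implementation; the name ravel_pytree_hybrid is hypothetical, and the choice to cast back only when the input dtypes differ is an assumption made for illustration.

```python
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_flatten, tree_unflatten

def ravel_pytree_hybrid(pytree):
    """Hypothetical hybrid: polymorphic unravel only for uniform input dtypes."""
    leaves, treedef = tree_flatten(pytree)
    if not leaves:
        return jnp.array([], jnp.float32), lambda _: tree_unflatten(treedef, [])

    leaves = [jnp.asarray(x) for x in leaves]
    from_dtypes = [x.dtype for x in leaves]
    to_dtype = jnp.result_type(*from_dtypes)
    uniform = all(dt == to_dtype for dt in from_dtypes)

    shapes = [x.shape for x in leaves]
    indices = np.cumsum([x.size for x in leaves])

    def unravel(flat):
        chunks = jnp.split(flat, indices[:-1])
        if uniform:
            # All inputs shared one dtype: keep whatever dtype `flat` has.
            out = [c.reshape(s) for c, s in zip(chunks, shapes)]
        else:
            # Mixed inputs: cast each chunk back to its original dtype.
            out = [c.reshape(s).astype(dt)
                   for c, s, dt in zip(chunks, shapes, from_dtypes)]
        return tree_unflatten(treedef, out)

    flat = jnp.concatenate([jnp.ravel(x).astype(to_dtype) for x in leaves])
    return flat, unravel

# Uniform input: a bool array passed to unravel stays bool.
x_flat, unravel_x = ravel_pytree_hybrid(jnp.arange(10, dtype=jnp.float32))
mask = unravel_x(x_flat < 5.3)

# Mixed input: dtypes are restored per leaf.
mixed_flat, unravel_mixed = ravel_pytree_hybrid(
    (jnp.arange(3, dtype=jnp.int32), jnp.ones(2, dtype=jnp.float32)))
mixed = unravel_mixed(mixed_flat)
```

The two branches in unravel are exactly the unpredictability in question: whether a dtype survives unraveling depends on whether the original leaves happened to share a dtype.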

Since the function is small, would it make sense just to have your own implementation? (Or, in the intervening time, did you find other solutions?)

Here’s an implementation which has the behavior you expected:

import jax
from jax.tree_util import tree_flatten, tree_unflatten
import jax.numpy as jnp
import numpy as np

def ravel_pytree(pytree):
  # Flatten the pytree into leaves, ravel them into one 1D array, and
  # return a function that rebuilds the original structure.
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = _ravel_list(leaves)
  unravel_pytree = lambda flat: tree_unflatten(treedef, unravel_list(flat))
  return flat, unravel_pytree

def _ravel_list(lst):
  if not lst: return jnp.array([], jnp.float32), lambda _: []

  # Require a single shared dtype so that unravel never has to cast back.
  from_dtypes = [jax.dtypes.result_type(l) for l in lst]
  to_dtype = jax.dtypes.result_type(*from_dtypes)
  if not all(from_dtype == to_dtype for from_dtype in from_dtypes):
    raise TypeError(f"all leaves must share one dtype, got {from_dtypes}")
  del from_dtypes, to_dtype

  # Plain list comprehensions replace jax.util.unzip2, which is deprecated
  # in newer JAX releases.
  sizes = [jnp.size(x) for x in lst]
  shapes = [jnp.shape(x) for x in lst]
  indices = np.cumsum(sizes)

  def unravel(arr):
    # No dtype cast here: the output keeps the dtype of `arr`.
    chunks = jnp.split(arr, indices[:-1])
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]

  raveled = jnp.concatenate([jnp.ravel(e) for e in lst])
  return raveled, unravel


x = jnp.arange(10, dtype=jnp.float32)
x_flat, unravel = ravel_pytree(x)
y = x_flat < 5.3
print(y.dtype)          # => <dtype: 'bool'>
print(unravel(y).dtype) # => <dtype: 'bool'>

Since I’m in an issue-closing mood, and since this is an old one, I’m going to somewhat preemptively close this issue. Let us know if we should reopen and continue the discussion!
