jax.ops.index_update of int array silently promotes nan to large integer

jax.ops.index_update on an int array silently converts nan to a large-magnitude integer (the minimum value of a 32-bit int, -2147483648). Is this the desired behavior? The analogous operation in NumPy raises a ValueError instead.

import jax
import jax.numpy as np
import numpy as onp

a = np.array([0, 47, 0])
b = a.astype(float)

print(jax.ops.index_update(a, a == 0, np.nan)) # [-2147483648          47 -2147483648]
print(jax.ops.index_update(b, b == 0, np.nan)) # [nan 47. nan]

oa = onp.array(a)
oa[oa == 0] = np.nan # ValueError: cannot convert float NaN to integer
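
One way to sidestep the silent cast, offered here as an assumption rather than a fix taken from the issue thread, is to convert the integer array to a float dtype explicitly before writing NaN, mirroring the `b = a.astype(float)` case above. The sketch below uses `jnp.where` and imports `jax.numpy` as `jnp` (the snippet above imports it as `np`).

```python
import jax.numpy as jnp

a = jnp.array([0, 47, 0])

# Assumption: float32 is an acceptable dtype for the result.
# NaN is only representable in floating-point dtypes, so cast first.
b = a.astype(jnp.float32)

# Replace the zero entries with NaN; non-zero entries are kept unchanged.
masked = jnp.where(b == 0, jnp.nan, b)

print(masked.dtype)
print(masked)
```

The result is a float array, so downstream code that needs integers has to handle the missing values some other way, for example with a separate boolean mask.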

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments:5 (1 by maintainers)

Top GitHub Comments

jakevdp commented, Jun 29, 2022

As of #10924, this behavior emits a warning, and in a future release it will be an error:

>>> import jax.numpy as jnp
>>> jnp.zeros(10, dtype=int).at[3].set(4.5)
jax/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
DeviceArray([0, 0, 0, 4, 0, 0, 0, 0, 0, 0], dtype=int32)

Same with the original issue:

>>> jnp.array([0, 47, 0]).at[0].set(jnp.nan)
/jax/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
DeviceArray([ 0, 47,  0], dtype=int32)
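
To surface the problem now rather than waiting for the future release, one option is to escalate `FutureWarning` to an exception with Python's standard `warnings` module. This is a minimal sketch, assuming a JAX version that emits the warning shown above. `set_strict` is a hypothetical helper name, not part of JAX.

```python
import warnings

import jax.numpy as jnp


def set_strict(x, idx, value):
    """Hypothetical helper: run x.at[idx].set(value) with FutureWarning as an error."""
    with warnings.catch_warnings():
        # Only inside this block, any FutureWarning is raised as an exception.
        warnings.simplefilter("error", FutureWarning)
        return x.at[idx].set(value)


x = jnp.array([0, 47, 0])

# Writing an int into an int array has compatible dtypes, so this succeeds.
print(set_strict(x, 0, 5))
```

Under the same assumption, a call such as `set_strict(x, 0, jnp.nan)` should raise instead of returning a silently cast array. The filter is scoped to the `with` block, so warnings elsewhere in the program are unaffected.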

hawkinsp commented, May 11, 2021

We were recently discussing a similar issue:

In [5]: x = jnp.zeros((10), np.int32)

In [6]: x.at[3].set(4.5)
Out[6]: DeviceArray([0, 0, 0, 4, 0, 0, 0, 0, 0, 0], dtype=int32)

This is surprising, and we think it would be better to report an error rather than silently converting from float to int, which is the same problem as in this issue.
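
The "report an error" behavior described here can be approximated in user code with an explicit dtype check before the update. The sketch below is an illustration only, not how JAX implements it. `checked_set` is a hypothetical helper, and the choice of `TypeError` is an assumption.

```python
import jax.numpy as jnp


def checked_set(x, idx, value):
    """Hypothetical helper: refuse float-into-int updates instead of truncating."""
    value_dtype = jnp.result_type(value)
    target_is_int = jnp.issubdtype(x.dtype, jnp.integer)
    value_is_float = jnp.issubdtype(value_dtype, jnp.floating)
    if target_is_int and value_is_float:
        raise TypeError(
            f"cannot safely write a {value_dtype} value into an array of dtype {x.dtype}"
        )
    return x.at[idx].set(value)


x = jnp.zeros(10, dtype=jnp.int32)

# An integer value is accepted and written as usual.
print(checked_set(x, 3, 4))
```

With this helper, `checked_set(x, 3, 4.5)` raises `TypeError` rather than storing 4. The check only covers the float-to-int case discussed in this thread. Other lossy casts, such as complex to float, would need additional conditions.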
