Specify casting rules and accepted input dtypes for reductions better
Reductions were added in PR gh-17, based on discussion in gh-10. There was quite a bit of discussion in calls as well around reductions (e.g., which ones to support, returning 0-D arrays and not scalars, naming) but not about casting rules and accepted input dtypes. It turns out that this is pretty inconsistent between libraries. Here’s a script that compares sum, std and prod:
import numpy as np
import dask.array as da
import torch
import tensorflow as tf
import jax.numpy as jnp
import mxnet
try:
    import cupy as cp
except ImportError:
    # CuPy is GPU-only, so may not be available
    cp = None


def ones(mod, shape):
    # Create a (3, 2)-shaped array of int8 1's
    if mod in (da, mxnet):
        x = mod.ones(shape, dtype=np.int8)  # MXNet doesn't have dtype literals
    else:
        x = mod.ones(shape, dtype=mod.int8)
    return x


def sum(mod, x):
    if mod == tf:
        y = tf.math.reduce_sum(x)
    else:
        y = mod.sum(x)
    return y


def std(mod, x):
    if mod == tf:
        y = tf.math.reduce_std(x)
    elif mod == mxnet:
        y = mxnet.std(x)
    else:
        y = mod.std(x)
    return y


def prod(mod, x):
    if mod == tf:
        y = tf.math.reduce_prod(x)
    else:
        y = mod.prod(x)
    return y


libraries = {
    'numpy': np,
    'pytorch': torch,
    'mxnet': mxnet.np,
    'dask': da,
    'tensorflow': tf,
    'jax': jnp,
}
if cp is not None:
    libraries['cupy'] = cp

results = libraries.copy()

# A separate call to get rid of TF and JAX noise:
shape = (3, 2)
_ = sum(tf, ones(tf, shape))
_ = sum(jnp, ones(jnp, shape))

print("\nsum(int8_array)\n" + "-"*15)
for name, mod in libraries.items():
    dtype = sum(mod, ones(mod, shape)).dtype
    print(f'{name}: {dtype}')

print("\nstd(int8_array)\n" + "-"*15)
for name, mod in libraries.items():
    try:
        dtype = std(mod, ones(mod, shape)).dtype
        print(f'{name}: {dtype}')
    except Exception as e:
        print(f'{name}: {repr(e)}')

print("\nprod(int8_array)\n" + "-"*16)
for name, mod in libraries.items():
    try:
        dtype = prod(mod, ones(mod, shape)).dtype
        print(f'{name}: {dtype}')
    except Exception as e:
        print(f'{name}: {repr(e)}')
And the result of that:
sum(int8_array)
---------------
numpy: int64
pytorch: torch.int64
mxnet: int8
dask: int64
tensorflow: <dtype: 'int8'>
jax: int32
cupy: int64
std(int8_array)
---------------
numpy: float64
pytorch: RuntimeError('std only supports floating-point dtypes')
mxnet: int8
dask: float64
tensorflow: TypeError('Input must be either real or complex')
jax: float32
cupy: float64
prod(int8_array)
----------------
numpy: int64
pytorch: torch.int64
mxnet: int8
dask: int64
tensorflow: <dtype: 'int8'>
jax: int32
cupy: int64
Conclusions
For sum(int8) and prod(int8) there appear to be two options:
- Return the default integer dtype (i.e. upcast the input). This is what NumPy, PyTorch, JAX and CuPy do (and Dask behavior is dictated by the underlying NumPy/CuPy arrays).
- Keep the input dtype. This is what TensorFlow and MXNet do. It is also what the spec currently says.
The TensorFlow docs do note this as the one inconsistency with NumPy: https://www.tensorflow.org/api_docs/python/tf/math/reduce_sum says “Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to int64 while tensorflow returns the same dtype as the input.”
The MXNet docs at https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.sum.html#mxnet-np-sum do not clearly say that this is expected, even though those docs do have a list of differences with NumPy (@szha thoughts on this?).
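To make the difference between the two options concrete, here is a small NumPy-only sketch. It assumes a 200-element int8 array (a size picked for illustration so that the total exceeds the int8 range) and uses the `dtype=` keyword of `np.sum` to emulate the keep-the-input-dtype behavior. It does not reproduce the TensorFlow or MXNet implementations.

```python
import numpy as np

# 200 ones do not fit in int8, whose maximum is 127 (size chosen for illustration)
x = np.ones(200, dtype=np.int8)

# Option 1: NumPy's default upcasts to the default integer dtype
upcast = np.sum(x)

# Option 2, emulated: force the accumulator and the result to stay int8
kept = np.sum(x, dtype=np.int8)

print(upcast.dtype, int(upcast))
print(kept.dtype, int(kept))  # the int8 total wraps around modulo 256
```

Under the keep-dtype rule the total overflows without any warning, which is one practical argument for upcasting small integer dtypes in sum and prod.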
For std(int8) there appear to be three options:
- Return the default floating-point dtype. This is what NumPy, Dask, CuPy and JAX do (float64, or float32 in the case of JAX).
- Raise an exception. For TensorFlow this is consistent with its design, because it doesn’t do int-to-float casting. Why PyTorch raises is unclear, probably for historical reasons (it has int-to-float casting now, but didn’t use to have it).
- Keep the input dtype. This is what MXNet does. It’s a consistent design rule, but clearly doesn’t make too much sense - I’d expect this one to be a mistake.
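A hedged sketch of the first option, and of a workaround that behaves the same regardless of which option a library picks, using NumPy only: casting explicitly to a floating-point dtype before calling std means the reduction never sees an integer input. The choice of float32 here is an assumption for illustration, not something the spec prescribes.

```python
import numpy as np

x = np.ones((3, 2), dtype=np.int8)

# Option 1 as NumPy implements it: integer input, default float result
implicit = np.std(x)

# Explicit cast first, so the reduction only ever sees a float dtype
explicit = np.std(x.astype(np.float32))

print(implicit.dtype)
print(explicit.dtype)
```

The explicit cast also makes the result dtype a decision of the caller rather than of the library's default floating-point dtype.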
This is all quite inconsistent, and needs to be considered more carefully for all reductions and dtypes.
Top GitHub Comments
The new C++ guidance is to infer the intermediate type from the operator - see P2322.
This is also relevant to the trace function in the linear algebra extension (there are other functions in linear algebra that accept integer inputs, but they all take two arguments, so they use normal type promotion).
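As an illustration of why trace raises the same question, here is a minimal NumPy sketch. It shows only NumPy's current behavior for an int8 matrix, plus the `dtype=` keyword used to emulate a keep-the-input-dtype rule. What the linear algebra extension should specify is the open question, and this sketch does not answer it.

```python
import numpy as np

m = np.eye(3, dtype=np.int8)

# Default: like np.sum, the diagonal is accumulated in the default integer dtype
t = np.trace(m)

# Emulated keep-dtype rule: accumulator and result forced to int8
t_kept = np.trace(m, dtype=np.int8)

print(t.dtype, t_kept.dtype)
```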