Specify casting rules and accepted input dtypes for reductions better

Reductions were added in PR gh-17, based on discussion in gh-10. There was quite a bit of discussion in calls as well around reductions (e.g., which ones to support, returning 0-D arrays and not scalars, naming) but not about casting rules and accepted input dtypes. It turns out that this is pretty inconsistent between libraries. Here’s a script that compares sum, std and prod:

import numpy as np
import dask.array as da
import torch
import tensorflow as tf
import jax.numpy as jnp
import mxnet
try:
    import cupy as cp
except ImportError:
    # CuPy is GPU-only, so may not be available
    cp = None


def ones(mod, shape):
    # Create a `shape`-shaped array of int8 ones
    if mod in (da, mxnet.np):
        # Dask and MXNet don't have dtype literals, so use NumPy's
        x = mod.ones(shape, dtype=np.int8)
    else:
        x = mod.ones(shape, dtype=mod.int8)

    return x

def sum(mod, x):
    if mod == tf:
        y = tf.math.reduce_sum(x)
    else:
        y = mod.sum(x)

    return y


def std(mod, x):
    if mod == tf:
        y = tf.math.reduce_std(x)
    else:
        # Works for mxnet.np as well, since it provides a std function
        y = mod.std(x)

    return y


def prod(mod, x):
    if mod == tf:
        y = tf.math.reduce_prod(x)
    else:
        y = mod.prod(x)

    return y


libraries = {
    'numpy': np,
    'pytorch': torch,
    'mxnet': mxnet.np,
    'dask': da,
    'tensorflow': tf,
    'jax': jnp,
}

if cp is not None:
    libraries['cupy'] = cp


# Warm-up calls, so TF and JAX emit their startup noise before the results:
shape = (3, 2)
_ = sum(tf, ones(tf, shape))
_ = sum(jnp, ones(jnp, shape))


print("\nsum(int8_array)\n" + "-"*15)
for name, mod in libraries.items():
    dtype = sum(mod, ones(mod, shape)).dtype
    print(f'{name}: {dtype}')

print("\nstd(int8_array)\n" + "-"*15)
for name, mod in libraries.items():
    try:
        dtype = std(mod, ones(mod, shape)).dtype
        print(f'{name}: {dtype}')
    except Exception as e:
        print(f'{name}: {repr(e)}')

print("\nprod(int8_array)\n" + "-"*16)
for name, mod in libraries.items():
    try:
        dtype = prod(mod, ones(mod, shape)).dtype
        print(f'{name}: {dtype}')
    except Exception as e:
        print(f'{name}: {repr(e)}')

And the output:

sum(int8_array)
---------------
numpy: int64
pytorch: torch.int64
mxnet: int8
dask: int64
tensorflow: <dtype: 'int8'>
jax: int32
cupy: int64

std(int8_array)
---------------
numpy: float64
pytorch: RuntimeError('std only supports floating-point dtypes')
mxnet: int8
dask: float64
tensorflow: TypeError('Input must be either real or complex')
jax: float32
cupy: float64

prod(int8_array)
----------------
numpy: int64
pytorch: torch.int64
mxnet: int8
dask: int64
tensorflow: <dtype: 'int8'>
jax: int32
cupy: int64

Conclusions

For sum(int8) and prod(int8) there appear to be two options:

  1. Return the default integer dtype (i.e., upcast the input). This is what NumPy, PyTorch, JAX and CuPy do (Dask's behavior is dictated by the underlying NumPy/CuPy chunks).
  2. Keep the input dtype. This is what TensorFlow and MXNet do, and it is also what the spec currently says. A sketch contrasting the two options follows below.

The TensorFlow docs do call out this inconsistency with NumPy: https://www.tensorflow.org/api_docs/python/tf/math/reduce_sum says “Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to int64 while tensorflow returns the same dtype as the input.”

The MXNet docs at https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.sum.html#mxnet-np-sum do not clearly say that this is expected, even though those docs do have a list of differences with NumPy (@szha thoughts on this?).
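
To make the difference concrete, here is a minimal sketch of the two options on top of NumPy; the helper names are hypothetical, and np.int64 is assumed to be the default integer dtype:

import numpy as np

# Option 1: upcast small integer inputs to the default integer dtype.
# Assumes np.int64 is the default integer dtype on this platform.
def sum_upcast(x):
    if np.issubdtype(x.dtype, np.integer):
        return np.sum(x, dtype=np.int64)
    return np.sum(x)

# Option 2: keep the input dtype, as the spec currently says.
# Note the silent overflow once the sum exceeds the int8 range.
def sum_keep_dtype(x):
    return np.sum(x, dtype=x.dtype)

x = np.ones((3, 2), dtype=np.int8)
print(sum_upcast(x).dtype)      # int64
print(sum_keep_dtype(x).dtype)  # int8

The overflow behavior is the practical difference: under option 2, summing 200 int8 ones wraps around to -56, while option 1 gives the expected 200.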

For std(int8) there appear to be three options:

  1. Return the default floating-point dtype. A sketch of this option follows below.
  2. Raise an exception. For TensorFlow this is consistent with its design, because it doesn't do int-to-float casting. Why PyTorch raises is unclear; probably for historical reasons (it supports int-to-float casting now, but didn't use to).
  3. Keep the input dtype. This is what MXNet does. It's a consistent design rule, but clearly doesn't make much sense - I'd expect this one to be a mistake.
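
A minimal sketch of option 1 on top of NumPy, assuming float64 is the default floating-point dtype (the helper name is hypothetical):

import numpy as np

def std_default_float(x):
    # Cast integer inputs to the default floating-point dtype first
    if np.issubdtype(x.dtype, np.integer):
        x = x.astype(np.float64)
    return np.std(x)

print(std_default_float(np.ones((3, 2), dtype=np.int8)).dtype)  # float64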

This is all quite inconsistent, and needs to be considered more carefully for all reductions and dtypes.

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 19 (16 by maintainers)

Top GitHub Comments

brycelelbach commented, Jun 17, 2021 (2 reactions)

The new C++ guidance is to infer the intermediate type from the operator - see P2322.

asmeurer commented, Jun 21, 2021 (1 reaction)

This is also relevant to the trace function in the linear algebra extension (there are other functions in linear algebra that accept integer inputs, but they all take two arguments, so they use normal type promotion).
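
For example, NumPy's trace follows the same accumulator rule as sum - integer inputs below the default integer precision are upcast (a quick check; assumes int64 is the default integer dtype):

import numpy as np

x = np.eye(3, dtype=np.int8)
# trace reduces over the diagonal, so the same dtype question applies
print(np.trace(x).dtype)  # int64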
