Reductions can quietly muck with loop in/out types

Simple example:

```python
import jax
jax.config.update('jax_enable_x64', True)

from jax import lax, numpy as jnp

def f(carry, y):
  # Without a dtype argument, jnp.prod promotes the int32 input to int64
  # here, so carry + prod(...) becomes int64 while the initial carry is int32.
  return carry + jnp.prod(y, axis=-1), None

lax.scan(f, init=jnp.zeros([2], dtype=jnp.int32), xs=jnp.ones([3, 2, 0], dtype=jnp.int32))[0]
```

This fails with:

```
TypeError: scan carry output and input must have identical types, got
ShapedArray(int64[2])
and
ShapedArray(int32[2]).
```

The issue seems to be that prod, when called without the dtype argument, upcasts integer inputs to the platform integer precision (in agreement with the NumPy spec), which JAX apparently interprets as 64-bit when 'jax_enable_x64' is set to True.
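A minimal sketch of one workaround, assuming a JAX version whose jnp.prod accepts NumPy's dtype argument (current releases do): request the carry's dtype explicitly so that every scan iteration returns the same type. The first print checks the promotion described above. Casting the reduction result with .astype(carry.dtype) would be an equally valid alternative.

```python
import jax
jax.config.update("jax_enable_x64", True)
from jax import lax, numpy as jnp

# With x64 enabled, an int32 product is promoted to the default int64.
print(jnp.prod(jnp.ones(3, jnp.int32)).dtype)  # int64

def f(carry, y):
    # Ask the reduction for the carry's dtype (int32) so the new carry
    # has exactly the same type as the incoming one.
    return carry + jnp.prod(y, axis=-1, dtype=carry.dtype), None

init = jnp.zeros([2], dtype=jnp.int32)
xs = jnp.ones([3, 2, 0], dtype=jnp.int32)  # product over an empty axis is 1
out, _ = lax.scan(f, init, xs)
print(out.dtype, out)  # int32 [3 3]
```

Because the last axis of xs has length 0, each product is 1, so three scan steps add up to 3 in each position.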

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 8 (7 by maintainers)

Top GitHub Comments

1 reaction
brianwa84 commented, May 21, 2020

Also bear in mind that the accumulator could use higher precision internally and still return a result downcast to the matching dtype.

On Thu, May 21, 2020, 12:58 AM Stephan Hoyer notifications@github.com wrote:

I’m sure that’s the idea, but if float32 overflow with multiplication is a serious worry, you should probably be worried about float64 overflow, too. It’s only an 8x larger range for the exponent. I just don’t think it makes a big enough difference to be worth the inconsistency.

(Original comment: https://github.com/google/jax/issues/3154#issuecomment-631880484)
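Both points in this exchange can be checked numerically. A minimal sketch, assuming x64 mode is enabled so that float64 arrays are genuinely 64-bit: NumPy's finfo confirms the 8x exponent range (maxexp of 128 for float32 versus 1024 for float64), and prod_keep_dtype is a hypothetical helper (not a JAX API) that accumulates in float64 and downcasts to the input dtype, as brianwa84 suggests.

```python
import jax
jax.config.update("jax_enable_x64", True)  # needed for real float64 arrays
import numpy as np
import jax.numpy as jnp

# float32 overflows beyond 2**128 and float64 beyond 2**1024,
# so float64 has an 8x larger exponent range (1024 / 128).
print(np.finfo(np.float32).maxexp, np.finfo(np.float64).maxexp)  # 128 1024

def prod_keep_dtype(x, axis=-1):
    # Hypothetical helper: accumulate the product in float64,
    # then downcast the result to the input's dtype.
    return jnp.prod(x.astype(jnp.float64), axis=axis).astype(x.dtype)

# The intermediate product 1e20 * 1e20 = 1e40 is outside float32's range,
# but the final value (about 1e10) fits comfortably.
x = jnp.array([1e20, 1e20, 1e-30], dtype=jnp.float32)
result = prod_keep_dtype(x)
print(result.dtype, result)
```

Since float32 tops out near 3.4e38, a float32 accumulator can overflow to inf here depending on evaluation order. The float64 accumulator instead returns a finite float32 value, and the caller still sees the input dtype.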

0 reactions
jakevdp commented, Feb 3, 2021

Revisiting this… it looks like this is only an issue for integer dtypes:

```
In [1]: import jax.numpy as jnp
   ...: from jax.test_util import dtypes
   ...: from jax import config; config.update('jax_enable_x64', True)
   ...: for dtype in dtypes.all:
   ...:     print(dtype.__name__, "->", jnp.zeros(2, dtype).sum().dtype)
   ...:     
bfloat16 -> bfloat16
float16 -> float16
float32 -> float32
float64 -> float64
int8 -> int64
int16 -> int64
int32 -> int64
int64 -> int64
uint8 -> uint64
uint16 -> uint64
uint32 -> uint64
uint64 -> uint64
complex64 -> complex64
complex128 -> complex128
bool_ -> int64
```

I think it’s reasonable to make the result type match the input by default: you could always recover the current behavior via explicit casts, but there’s currently no way to make the reduction preserve the input dtype.
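The session above imports dtypes from jax.test_util, a testing helper that may not be importable from newer JAX releases. A minimal sketch using only public jax.numpy calls lists the integer dtypes explicitly and compares the current promoting default with the proposed dtype-preserving one. sum_keep_dtype is a hypothetical wrapper for illustration, not a JAX API.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

# Explicit stand-in for the dtypes.all helper used in the session above
# (integer dtypes only, since those are the ones that get promoted).
int_dtypes = [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
              jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]

def sum_keep_dtype(x, axis=None):
    # Hypothetical wrapper sketching the proposed default:
    # reduce, then return the result in the input's own dtype.
    return jnp.sum(x, axis=axis).astype(x.dtype)

for dtype in int_dtypes:
    x = jnp.zeros(2, dtype)
    promoted = jnp.sum(x).dtype     # current default (promotes)
    kept = sum_keep_dtype(x).dtype  # proposed default (preserves)
    print(x.dtype, "->", promoted, "| kept:", kept)
```

Going the other way, jnp.sum(x.astype(jnp.int64)) reproduces the promoted result for a signed input, one example of the explicit casts the comment mentions. A blanket cast back to the input dtype would not make sense for bool_, where the sum counts True values.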
