lax.reduce: support closure-over Tracers in reduction function

Hi,

I’m using jax.lax.reduce to create a (sort of) fingerprint function. Here’s a simplified version of my code that reproduces the problem:

import jax
import jax.numpy as jnp

def fingerprint(array, axes=None):
  array = jnp.asarray(array)  # convert first so non-array inputs also have a dtype
  assert array.dtype in (jnp.int32, jnp.uint32, jnp.float32)
  if axes is None:
    axes = tuple(range(array.ndim))
  array = jax.lax.bitcast_convert_type(array, jnp.int32)
  magic = jnp.uint32(0x9e3779b9).astype(jnp.int32)

  def combine_fn(x, y):
    y = (y + magic + jax.lax.shift_left(x, x.dtype.type(6)) +
         jax.lax.shift_right_logical(x, x.dtype.type(2)))
    return jax.lax.bitwise_xor(x, y)

  return jax.lax.reduce(array, magic, combine_fn, axes)

When I use a jitted version of the function, I get an AssertionError in jax.lax.reduce.

x = jax.random.normal(jax.random.PRNGKey(0), (10, 8, 5))
print(fingerprint(x))  # prints: -1114046056
print(jax.jit(fingerprint)(x))  # raises exception

Raises:

AssertionError                            Traceback (most recent call last)
google3/third_party/py/jax/_src/lax/lax.py in _reduction_computation(c, jaxpr, consts, init_values, singleton)
   5284   axis_env = xla.AxisEnv(1, (), ())  # no parallel primitives inside reductions
   5285   subc = xla_bridge.make_computation_builder("reduction_computation")
-> 5286   assert len(consts) == 0, "Reduction computations cannot have constants"
   5287   args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
   5288   out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)

AssertionError: Reduction computations cannot have constants

The input array is not a constant, but the initial value (magic) is, so I assume that is where the problem lies.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
jpuigcerver commented, Aug 4, 2021

@froystig Your solution worked like a charm 👍

0 reactions
froystig commented, Aug 3, 2021

“This code used to work a few months ago. Why can’t magic be closed over by combine_fn?”

A lot has changed over the past year that could have affected this, including our entire approach to staging out functions and what we hoist as a constant, so bisection would require some effort and might not be too enlightening. The current state is that we don’t support closure-captured values, but we probably want to.

Setting aside @jakevdp’s observation (which is worth looking into separately), a simple workaround is to inline the expression for magic into combine_fn rather than closing over the resulting value. It isn’t especially elegant, but this works for me:

def magic():
  return jnp.uint32(0x9e3779b9).astype(jnp.int32)

def combine_fn(x, y):
  y = (y + magic() + jax.lax.shift_left(x, x.dtype.type(6)) +
       jax.lax.shift_right_logical(x, x.dtype.type(2)))
  return jax.lax.bitwise_xor(x, y)

def fingerprint(array, axes=None):
  array = jnp.asarray(array)  # convert first so non-array inputs also have a dtype
  assert array.dtype in (jnp.int32, jnp.uint32, jnp.float32)
  if axes is None:
    axes = tuple(range(array.ndim))
  array = jax.lax.bitcast_convert_type(array, jnp.int32)
  return jax.lax.reduce(array, magic(), combine_fn, axes)

We should look into supporting closure. While we ponder that, does the above at least unblock you and get things compiling?
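
As a quick sanity check of the workaround, the sketch below runs the same fingerprint function both eagerly and under jax.jit. It assumes a JAX version in which building the constant inside combine_fn avoids the assertion, as reported in this thread. The names eager_result and jitted_result are introduced here for illustration only. The two values are printed but not compared, since nothing in this thread establishes the order in which the reduction combines elements.

```python
import jax
import jax.numpy as jnp


def magic():
  # Rebuilt on every call, so combine_fn never closes over a traced value.
  return jnp.uint32(0x9e3779b9).astype(jnp.int32)


def combine_fn(x, y):
  # Same mixing step as in the thread; magic() is evaluated inside the reducer.
  y = (y + magic() + jax.lax.shift_left(x, x.dtype.type(6)) +
       jax.lax.shift_right_logical(x, x.dtype.type(2)))
  return jax.lax.bitwise_xor(x, y)


def fingerprint(array, axes=None):
  # Assumption of this sketch: convert before checking the dtype.
  array = jnp.asarray(array)
  assert array.dtype in (jnp.int32, jnp.uint32, jnp.float32)
  if axes is None:
    axes = tuple(range(array.ndim))
  array = jax.lax.bitcast_convert_type(array, jnp.int32)
  return jax.lax.reduce(array, magic(), combine_fn, axes)


x = jax.random.normal(jax.random.PRNGKey(0), (10, 8, 5))
eager_result = fingerprint(x)
# axes is left at its default (None), so jit needs no static arguments here.
jitted_result = jax.jit(fingerprint)(x)
print(eager_result, jitted_result)
```

If axes were passed explicitly to the jitted function, it would likely need to be marked static (for example via static_argnames), because it is used to build a Python tuple of dimensions.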
