lax.reduce: support closure-over Tracers in reduction function
Hi,
I’m using jax.lax.reduce to create a (sort of) fingerprint function. Here’s a simplified version of my code that reproduces the problem:
```python
import jax
import jax.numpy as jnp

def fingerprint(array, axes=None):
    assert array.dtype in (jnp.int32, jnp.uint32, jnp.float32)
    array = jnp.asarray(array)
    if axes is None:
        axes = tuple(range(array.ndim))
    array = jax.lax.bitcast_convert_type(array, jnp.int32)
    magic = jnp.uint32(0x9e3779b9).astype(jnp.int32)

    def combine_fn(x, y):
        y = (y + magic + jax.lax.shift_left(x, x.dtype.type(6)) +
             jax.lax.shift_right_logical(x, x.dtype.type(2)))
        return jax.lax.bitwise_xor(x, y)

    return jax.lax.reduce(array, magic, combine_fn, axes)
```
When I use a jitted version of the function, I get an AssertionError in jax.lax.reduce.
```python
x = jax.random.normal(jax.random.PRNGKey(0), (10, 8, 5))
print(fingerprint(x))           # prints: -1114046056
print(jax.jit(fingerprint)(x))  # raises exception
```
Raises:
```
AssertionError Traceback (most recent call last)
google3/third_party/py/jax/_src/lax/lax.py in _reduction_computation(c, jaxpr, consts, init_values, singleton)
   5284   axis_env = xla.AxisEnv(1, (), ())  # no parallel primitives inside reductions
   5285   subc = xla_bridge.make_computation_builder("reduction_computation")
-> 5286   assert len(consts) == 0, "Reduction computations cannot have constants"
   5287   args = [xb.parameter(subc, i, shape) for i, shape in enumerate(shapes)]
   5288   out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
AssertionError: Reduction computations cannot have constants
```
The input array is not a constant, but the init value (`magic`) is, so I assume that is where the problem lies.
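A quick way to check what `magic` actually is under `jax.jit` is to record its type while the function is being traced. This is an illustrative sketch only: `build_magic` and `observed` are hypothetical names, and it assumes a JAX version that still exports `jax.core.Tracer`. Under `jit` the `astype` call is staged out, so `magic` should be a tracer; that would mean `combine_fn` is closing over a traced value, which fits the maintainer's reply below about closure-captured values better than the init value itself.

```python
import jax
import jax.numpy as jnp

observed = {}  # hypothetical: filled in while build_magic is being traced


@jax.jit
def build_magic(x):
    # Same expression as in fingerprint() above.
    magic = jnp.uint32(0x9e3779b9).astype(jnp.int32)
    # Python side effects inside a jitted function run once, at trace time,
    # so this records the kind of object combine_fn would be closing over.
    observed["magic_is_tracer"] = isinstance(magic, jax.core.Tracer)
    return x + magic


build_magic(jnp.int32(0))
print(observed)
```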
Issue Analytics
- Created 2 years ago
- Comments:5 (4 by maintainers)
Top GitHub Comments
@froystig Your solution worked like a charm 👍
A lot has changed over the past year that could have affected this, including our entire approach to staging out functions and what we hoist as a constant, so bisecting would take some effort and might not be very enlightening. The current state is that we don’t support closure-captured values in the reduction function, but we probably want to.
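When the captured value genuinely has to be traced (for example, a salt passed in as a jitted argument), inlining a constant is not an option. One alternative, sketched here under the assumption that a strictly sequential fold is acceptable, is to swap `lax.reduce` for `lax.scan`, whose body may close over traced values. `fingerprint_scan` and `salt` are hypothetical names. This is a different computation from `lax.reduce`: it folds the flattened array left to right in a fixed order, so its results should not be expected to match the reduce-based version. It also sidesteps the fact that `combine_fn` is not associative, whereas the `jax.lax.reduce` documentation asks that `init_values` and the computation form a monoid.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fingerprint_scan(array, salt):
    """Hypothetical variant: sequential fold with lax.scan instead of lax.reduce.

    `salt` may be a traced value under jit; the scan body closes over it.
    """
    flat = jax.lax.bitcast_convert_type(jnp.ravel(jnp.asarray(array)), jnp.int32)
    # Explicit dtype gives a strongly typed int32 scalar, matching the carry.
    salt = jnp.asarray(salt, jnp.int32)

    def body(acc, elem):
        # Same mixing step as combine_fn, with the traced salt closed over.
        y = (elem + salt + jax.lax.shift_left(acc, np.int32(6)) +
             jax.lax.shift_right_logical(acc, np.int32(2)))
        return jax.lax.bitwise_xor(acc, y), None

    out, _ = jax.lax.scan(body, salt, flat)
    return out


x = jax.random.normal(jax.random.PRNGKey(0), (10, 8, 5))
print(jax.jit(fingerprint_scan)(x, 7))  # salt is traced inside jit here
```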
Setting aside @jakevdp’s observation (which is worth looking into separately), a simple workaround is to inline the expression for `magic` into `combine_fn` rather than closing over the resulting value. It isn’t especially elegant, but this works for me:

We should look into supporting closure. While we ponder that, does the above at least unblock you and get things compiling?