implement AD for `lax.reduce_window`
In this example I define 4 identical pooling functions (notebook: https://colab.research.google.com/gist/romanngg/fdd96829faaf0a7e666bccc5fd27ba58/pooling_fail.ipynb):
```python
import jax.numpy as np  # `np` is jax.numpy here (old-style alias), per the explanation below
from jax import jit, lax, random as r

p = r.normal(r.PRNGKey(1), (1, 1))
args = (1, 1), (2, 1), [(0, 0), (1, 0)]  # window_dimensions, window_strides, padding

def pool_1(p):
  return lax._reduce_window_sum(p, *args)

def pool_2(p):
  return lax.reduce_window(p, np.zeros((), p.dtype), lax.add, *args)

pool_1_jit = jit(pool_1)
pool_2_jit = jit(pool_2)
```
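The notebook's failing calls aren't reproduced in this copy of the issue; judging from the transformation names in the error below (`jit(vmap(transpose(jvp(pool_1))))`), they were presumably reverse-mode Jacobians of the jitted functions, along these lines (a guess, continuing from the snippet above):

```python
from jax import jacrev

jacrev(pool_1_jit)(p)  # reportedly hits the XLA "negative low padding" RuntimeError below
jacrev(pool_2_jit)(p)  # reportedly raises NotImplementedError for 'reduce_window'
```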
but as you can see, differentiating their jitted versions doesn't work, producing different errors, e.g. for `lax.reduce_window`:

```
NotImplementedError: Differentiation rule for 'reduce_window' not implemented
```

or for `lax._reduce_window_sum`:

```
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in backend_compile(backend, built_c, options)
383 # we use a separate function call to ensure that XLA compilation appears
384 # separately in Python profiling results
--> 385 return backend.compile(built_c, compile_options=options)
386
387 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
RuntimeError: Invalid argument: Window dimensions {
size: 1
stride: 1
window_dilation: 1
base_dilation: 1
}
dimensions {
size: 1
stride: 1
window_dilation: 1
base_dilation: 1
}
dimensions {
size: 1
stride: 1
padding_low: -1
window_dilation: 1
base_dilation: 1
}
has a negative low padding.
, for instruction %reduce-window = f32[2,1,1]{2,1,0} reduce-window(f32[2,1,2]{2,1,0} %parameter.1, f32[] %constant.5), window={size=1x1x1 pad=0_0x0_0x-1_0}, to_apply=%primitive_computation_add.6, metadata={op_type="reduce_window_sum" op_name="jit(vmap(transpose(jvp(pool_1))))/reduce_window_sum[ base_dilation=(1, 1, 1)\n padding=((0, 0), (0, 0), (0, 0))\n window_dilation=(1, 1, 1)\n window_dimensions=(1, 1, 1)\n window_strides=(1, 1, 1) ]" source_file="<ipython-input-2-ee69b86aa221>" source_line=6}
Failed after simplification
```
Thanks a lot for such a detailed explanation, Roy!
Thanks for highlighting that! I overlooked it in the notebook that you shared.
What's happening is a bit tricky, but has to do with how we handle a special case of `lax.reduce_window` as an optimization just before binding the underlying primitive (which has no AD rule): https://github.com/google/jax/blob/e7e5140dc9fd192154e866794424d073e5c8d2b0/jax/_src/lax/lax.py#L1387-L1393

This tries to identify known monoids (an associative operation with an identity element) as the reduction function and initial value. For instance, if we successfully notice a reduction with `(+, 0)`, then we proceed to call an internal `_reduce_window_sum`, rather than following the general path. The sum-specific path has AD rules.

When jax transforms a function, it interprets it with abstract arguments. The degree of abstraction depends on the transformation involved. For autodiff alone, we still have enough information to concretely identify the `0` part of the `(+, 0)`, but once we involve compilation as well, we abstract the computation further (to shapes/dtypes) and can at most say that we have addition with a float32 scalar initial value, i.e. `(+, f32[])`. Since we can't guarantee the `0` initial value, we continue on the general path, for which there is no AD rule.

Note also that, if you were to reduce with an operation other than `+`, this would also send you down the general path, and hence would err with or without `jit`.
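One way to see the dispatch difference described above (not from the thread, just an illustrative sketch; the helper `pool` and the variable names are mine, and the primitive names printed may vary across JAX versions) is to stage out both variants with `jax.make_jaxpr` and check which primitive appears:

```python
import numpy as onp  # plain NumPy: stays a concrete constant even under tracing
import jax
import jax.numpy as jnp
from jax import lax

window_dims, strides, padding = (1, 1), (2, 1), [(0, 0), (1, 0)]
x = jnp.ones((1, 1))

def pool(p, init):
  return lax.reduce_window(p, init, lax.add, window_dims, strides, padding)

# Concrete NumPy zero: the (+, 0) monoid should be recognized, so the staged
# program should contain the sum-specific primitive (which has AD rules).
print(jax.make_jaxpr(lambda p: pool(p, onp.zeros((), p.dtype)))(x))

# Zero built with jax.numpy under the trace: it is staged out and only known
# to be an f32[] scalar, so the general `reduce_window` primitive (no AD rule,
# per this thread) should be bound instead.
print(jax.make_jaxpr(lambda p: pool(p, jnp.zeros((), p.dtype)))(x))
```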
I was about to suggest, as a workaround, to hard-code the initial value as a constant zero by rewriting:
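(The exact snippet that followed isn't preserved in this copy of the thread. A hypothetical reconstruction of the kind of rewrite described, passing a concrete NumPy zero so the `(+, 0)` monoid stays recognizable even under `jit`, might look like the following; `pool_3` and the `onp` alias are mine, continuing from the snippet at the top.)

```python
import numpy as onp  # plain NumPy, so the init value stays a concrete constant

def pool_3(p):
  # hypothetical reconstruction: hard-code a concrete zero initial value so the
  # sum-specific path (_reduce_window_sum) is taken even when jit-compiled
  return lax.reduce_window(p, onp.zeros((), p.dtype), lax.add, *args)

pool_3_jit = jit(pool_3)
```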
But it turns out this actually reproduces your original XLA runtime error, from differentiating `pool_1_jit`, using only public `lax` functions! We must be hitting the same bug here that you encountered when differentiating `_reduce_window_sum` directly. In hindsight, your code set up that test case correctly, even if calling internal functions isn't otherwise allowed. Apologies there!

In summary, we have a feature request (for AD on the general path) and a bug (with staging+AD on the special path) at the same time.