implement AD for `lax.reduce_window`


In this example I define 4 identical pooling functions (two of them are reproduced below; the full notebook is at https://colab.research.google.com/gist/romanngg/fdd96829faaf0a7e666bccc5fd27ba58/pooling_fail.ipynb):

import jax.numpy as np
from jax import jit, lax, random as r

# A tiny (1, 1) input plus the reduce_window arguments:
# window_dimensions=(1, 1), window_strides=(2, 1), padding=[(0, 0), (1, 0)].
p = r.normal(r.PRNGKey(1), (1, 1))
args = (1, 1), (2, 1), [(0, 0), (1, 0)]


def pool_1(p):
  # Internal sum-specific reduce_window.
  return lax._reduce_window_sum(p, *args)


def pool_2(p):
  # Public reduce_window with an explicit (+, 0) reduction.
  return lax.reduce_window(p, np.zeros((), p.dtype), lax.add, *args)


pool_1_jit = jit(pool_1)
pool_2_jit = jit(pool_2)
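
The differentiation calls themselves are not shown above; judging from the op_name recorded in the traceback further down (jit(vmap(transpose(jvp(pool_1))))), they are presumably of the form:

from jax import jacobian

# Hypothetical repro calls, inferred from the traceback metadata (each raises):
# jacobian(pool_2_jit)(p)  # NotImplementedError for 'reduce_window'
# jacobian(pool_1_jit)(p)  # RuntimeError from XLA about negative low padding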

But as you can see, differentiating their jitted versions doesn't work, and each produces a different error. For lax.reduce_window:

NotImplementedError: Differentiation rule for 'reduce_window' not implemented

And for lax._reduce_window_sum:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in backend_compile(backend, built_c, options)
    383   # we use a separate function call to ensure that XLA compilation appears
    384   # separately in Python profiling results
--> 385   return backend.compile(built_c, compile_options=options)
    386 
    387 def _execute_compiled_primitive(prim, compiled, result_handler, *args):

RuntimeError: Invalid argument: Window dimensions {
  size: 1
  stride: 1
  window_dilation: 1
  base_dilation: 1
}
dimensions {
  size: 1
  stride: 1
  window_dilation: 1
  base_dilation: 1
}
dimensions {
  size: 1
  stride: 1
  padding_low: -1
  window_dilation: 1
  base_dilation: 1
}
 has a negative low padding.
	, for instruction %reduce-window = f32[2,1,1]{2,1,0} reduce-window(f32[2,1,2]{2,1,0} %parameter.1, f32[] %constant.5), window={size=1x1x1 pad=0_0x0_0x-1_0}, to_apply=%primitive_computation_add.6, metadata={op_type="reduce_window_sum" op_name="jit(vmap(transpose(jvp(pool_1))))/reduce_window_sum[ base_dilation=(1, 1, 1)\n                                                     padding=((0, 0), (0, 0), (0, 0))\n                                                     window_dilation=(1, 1, 1)\n                                                     window_dimensions=(1, 1, 1)\n                                                     window_strides=(1, 1, 1) ]" source_file="<ipython-input-2-ee69b86aa221>" source_line=6}

Failed after simplification

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Reactions: 2
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
romanngg commented, Sep 10, 2021

Thanks a lot for such a detailed explanation, Roy!

1 reaction
froystig commented, Sep 8, 2021

Thanks for highlighting that! I overlooked it in the notebook that you shared.

What’s happening is a bit tricky, but has to do with how we handle a special case of lax.reduce_window as an optimization just before binding the underlying primitive (which has no AD rule):

https://github.com/google/jax/blob/e7e5140dc9fd192154e866794424d073e5c8d2b0/jax/_src/lax/lax.py#L1387-L1393

This tries to identify a known monoid (an associative operation with an identity element) in the reduction function and initial value. For instance, if we successfully notice a reduction with (+, 0), then we call an internal _reduce_window_sum rather than following the general path. The sum-specific path has AD rules.
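
A heavily simplified, hypothetical sketch of that check (the real logic is in the lax.py source linked above; the function name which_path is mine), just to illustrate that the special path is taken only when the computation is literally lax.add and the initial value is a concrete scalar zero:

import numpy as onp
from jax import core, lax

def which_path(computation, init_value):
  # Inspect the abstract value of the initial value: only a *concrete*
  # scalar equal to 0 can be recognized as the identity of (+, 0).
  aval = core.get_aval(init_value)
  concrete_zero = (isinstance(aval, core.ConcreteArray)
                   and aval.shape == () and aval.val == 0)
  if computation is lax.add and concrete_zero:
    return "special path: _reduce_window_sum (has AD rules)"
  return "general path: reduce_window primitive (no AD rule)"

print(which_path(lax.add, onp.zeros(())))             # special path
print(which_path(lambda a, b: a + b, onp.zeros(())))  # general path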

When jax transforms a function, it interprets it with abstract arguments. The degree of abstraction depends on the transformation involved. For autodiff alone, we still have enough information to concretely identify the 0 part of the (+, 0), but once we involve compilation as well, we abstract the computation further (to shapes/dtypes) and can at most say that we have addition with a float32 scalar initial value, i.e. “(+, f32[]).” Since we can’t guarantee the 0 initial value, we continue on the general path, for which there is no AD rule.
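
Concretely (reusing p, args, and pool_2 from the example above; the second call is presumably what produced the NotImplementedError reported in the issue):

from jax import jacobian, jit

print(jacobian(pool_2)(p))   # works: AD alone still sees a concrete (+, 0)
# jacobian(jit(pool_2))(p)   # raises NotImplementedError for 'reduce_window'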

Note also that, if you were to reduce with an operation other than +, this would also send you down the general path, and hence would err with or without jit.
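
For instance (my own illustration, not from the thread, reusing p and args from above), a product-pooling variant, (*, 1), is presumably not one of the recognized monoids, so it takes the general path and fails to differentiate with or without jit:

import numpy as onp
from jax import jacobian, jit, lax

def pool_prod(p):
  # Multiplicative reduction: not special-cased like (+, 0), so no AD rule applies.
  return lax.reduce_window(p, onp.ones((), p.dtype), lax.mul, *args)

# Both presumably raise NotImplementedError for 'reduce_window':
# jacobian(pool_prod)(p)
# jacobian(jit(pool_prod))(p)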

I was about to suggest, as a workaround, hard-coding the initial value as a constant zero by rewriting:

import numpy as onp  # plain NumPy, so the zero stays a concrete constant under jit tracing
from jax import jacobian, jit, lax

# p and args as defined in the original example above.
def pool_2(p):
  return lax.reduce_window(p, onp.zeros((), p.dtype), lax.add, *args)

pool_2_jit = jit(pool_2)
print(jacobian(pool_2_jit)(p))

But it turns out this actually reproduces your original XLA runtime error, from differentiating pool_1_jit, using only public lax functions! We must be hitting the same bug here that you encountered when differentiating _reduce_window_sum directly. In hindsight, your code set up that test case correctly, even if calling internal functions isn’t otherwise allowed. Apologies there!

In summary, we have a feature request (for AD on the general path) and a bug (with staging+AD on the special path) at the same time.
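
Until those are addressed, one possible stopgap (not discussed in this thread; a minimal sketch, assuming NHWC inputs and that differentiable sum pooling is the end goal) is to express sum pooling as a depthwise convolution with an all-ones kernel, since lax.conv_general_dilated has AD rules and works under jit:

import jax.numpy as jnp
from jax import jacobian, jit, lax

def sum_pool_nhwc(x, window=(2, 2), strides=(2, 2), padding="VALID"):
  # x: (N, H, W, C). A depthwise conv with a ones kernel sums each window per channel.
  c = x.shape[-1]
  kernel = jnp.ones(window + (1, c), x.dtype)  # HWIO with I=1, O=C (depthwise)
  return lax.conv_general_dilated(
      x, kernel, window_strides=strides, padding=padding,
      dimension_numbers=("NHWC", "HWIO", "NHWC"),
      feature_group_count=c)

x = jnp.ones((1, 4, 4, 2))
print(jit(sum_pool_nhwc)(x).shape)             # (1, 2, 2, 2)
print(jacobian(jit(sum_pool_nhwc))(x).shape)   # differentiates under jit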

