
Conditional derivatives should raise a warning or exception


This was really tricky to track down and reproduce. The bug manifests when you need to differentiate through a conditional itself (as opposed to its branches). I may be using different terminology from you folks, so I think some code is most illustrative:

import numpy as onp
import jax
import jax.numpy as np

def foo(x):
    if onp.random.rand() < np.cos(x):
        return 2
    else:
        return 3

def expectation(val):
    total = 0
    counts = 50000
    for _ in range(counts):
        total += foo(val)
    return total/counts

# we can also compute the expectation
# analytically as:

#   P(true)*2 + P(false)*3
# = cos(x)*2 + (1-cos(x))*3

x = expectation(0.5) 
print(x)
# for large counts, this returns approximately
# 2.12241743811

# the gradient of the expectation should be:
# -sin(x)*2 + sin(x)*3
# 0.4794255386

grad_expectation = jax.jacfwd(expectation, argnums=(0,))
# always returns zero because foo returns constants
print("autodiff", grad_expectation(0.5)) 

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments: 12 (12 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, May 6, 2019

I mostly disagree with bullet 2; JAX is returning a sensible derivative. That is, $\partial [ x \mapsto f(x, w) ] = 0$ for almost all $x$ and $w$ (i.e. everywhere except when $w = \cos(x)$), where $f(x, w) = 2 \cdot I[w < \cos(x)] + 3 \cdot I[w \geq \cos(x)]$. If you disagree, what do you think the derivative of that function is? (Notice there’s no integration here.)

When $w = \cos(x)$, no Fréchet derivative exists, and it would be reasonable to raise an error there, but (1) that’s not affecting the integration (we can do surgery on a set of measure zero), so I’m assuming that’s not the error you’re talking about, and (2) JAX gives you one of the directional derivatives (the one from the right).

In other words, nothing’s going wrong with the Monte Carlo estimate of the integral of the derivative; the integral of the derivative is zero for all x, as the Monte Carlo approximation reports (with zero variance, in fact!). The problem is that what you actually want is the derivative of the integral, but the derivative of the integral is not equal to the integral of the derivative in general.
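
To make the distinction concrete for the example in the issue, here is a short worked calculation. It assumes $w$ is uniform on $[0, 1]$ (as `onp.random.rand()` is) and that $0 \leq \cos(x) \leq 1$, which holds at $x = 0.5$:

$$
\frac{d}{dx} \int_0^1 f(x, w)\, dw = \frac{d}{dx} \left[ 2 \cos(x) + 3 \left(1 - \cos(x)\right) \right] = \sin(x),
\qquad
\int_0^1 \frac{\partial f}{\partial x}(x, w)\, dw = \int_0^1 0 \, dw = 0.
$$

At $x = 0.5$ the first quantity is $\sin(0.5) \approx 0.4794$, matching the analytic gradient quoted in the issue, while the second is the zero that autodiff reports.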

If you want to form a valid Monte Carlo approximation to the derivative of the integral, one way to do it would be to find an integral representation that satisfies the conditions of the Leibniz rule. In the recent ML literature we tend to call those “reparameterization gradients” and there’s a bit of a cottage industry in developing them for common expectations. I’m sure there’s a much longer history, with different terminology, in physics.
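
As a rough illustration of a valid Monte Carlo gradient for the example in the issue, here is a minimal sketch. Note that it swaps in the score-function (REINFORCE-style) estimator rather than a reparameterization gradient, because the branch is a discrete Bernoulli draw and does not reparameterize directly. Nothing here comes from the thread: the names `log_prob`, `payoff`, `score_function_grad` and `exact_expectation` are hypothetical, and the sample count is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def log_prob(x, b):
    # Log-probability of the branch indicator b, where b = 1.0 means the
    # "true" branch of foo, which is taken with probability cos(x).
    p = jnp.cos(x)
    return b * jnp.log(p) + (1.0 - b) * jnp.log1p(-p)


def payoff(b):
    # foo returns 2 on the true branch and 3 otherwise.
    return jnp.where(b == 1.0, 2.0, 3.0)


def score_function_grad(key, x, num_samples=200_000):
    # Draw the branch indicators; they are constants with respect to x.
    b = jax.random.bernoulli(key, jnp.cos(x), (num_samples,)).astype(jnp.float32)
    # Per-sample score d/dx log P(b; x), obtained by autodiff.
    scores = jax.vmap(jax.grad(log_prob), in_axes=(None, 0))(x, b)
    # Unbiased estimate of d/dx E[payoff(b)].
    return jnp.mean(payoff(b) * scores)


def exact_expectation(x):
    # Alternative: integrate the branch out analytically, then differentiate.
    return 2.0 * jnp.cos(x) + 3.0 * (1.0 - jnp.cos(x))


x = 0.5
estimate = score_function_grad(jax.random.PRNGKey(0), x)
exact = jax.grad(exact_expectation)(x)
print("score-function estimate:", estimate)
print("exact gradient:", exact)
```

The estimate is noisy (this estimator has fairly high variance), but it should land near the analytic value $\sin(0.5)$, unlike differentiating each sample of `foo` directly. When the discrete variable can be summed out in closed form, as in `exact_expectation`, differentiating that expression is simpler and exact.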

0 reactions
fehiepsi commented, May 7, 2019

do you know of work on automatically forming reparameterization-style Monte Carlo estimators for derivatives of integral representations?

I am not familiar with this line of work. The above discussion is very interesting! I’ll set aside some time this weekend to follow up. 😃
