
jacfwd gives different results from jacrev


When working with an electrostatic pairwise potential, I'm seeing nans in jacrev but not in jacfwd:

import jax
import jax.numpy as np

from jax.scipy.special import erfc

def delta_r(ri, rj, box=None):
    diff = ri - rj # this can be either N,N,3 or B,3
    if box is not None:
        # wrap the displacement back into the periodic (triclinic) box,
        # one lattice vector at a time (minimum-image convention)
        diff -= box[2]*np.floor(np.expand_dims(diff[...,2], axis=-1)/box[2][2]+0.5)
        diff -= box[1]*np.floor(np.expand_dims(diff[...,1], axis=-1)/box[1][1]+0.5)
        diff -= box[0]*np.floor(np.expand_dims(diff[...,0], axis=-1)/box[0][0]+0.5)
    return diff

def distance(ri, rj, box=None):
    dxdydz = np.power(delta_r(ri, rj, box), 2)
    # np.linalg.norm gives nans here, but this formulation doesn't
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return dij


def pairwise(conf, charges, box=None):
    num_atoms = conf.shape[0]
    qi = np.expand_dims(charges, 0) # (1, N)
    qj = np.expand_dims(charges, 1) # (N, 1)
    qij = np.multiply(qi, qj)
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)
    eij = qij/dij
    eij = np.where(np.eye(num_atoms), np.zeros_like(eij), eij) # zero out diagonals

    # print(dij)
    alphaEwald = 1.0
    eij_direct = np.where(dij > 2.0, np.zeros_like(eij), eij)
    eij_direct *= erfc(alphaEwald*eij_direct)
    return np.sum(eij_direct)/2

if __name__ == "__main__":
    # note: without enabling 64-bit mode in JAX, these arrays are stored as float32
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [ 0.0637,   0.0126,   0.2203],
        [ 1.0573,  -0.2011,   1.2864],
        [ 2.3928,   1.2209,  -0.2230],
        [-0.6891,   1.6983,   0.0780],
        [-0.6312,  -1.6261,  -0.2601]
    ], dtype=np.float64)


    box = np.array([
        [2.0, 0.0, 0.0],
        [0.6, 1.6, 0.0],
        [0.4, 0.7, 1.1]
    ], dtype=np.float64)

    print(jax.jacfwd(pairwise, 0)(conf, charges, box))
    print(jax.jacfwd(pairwise, 1)(conf, charges, box))

    print(jax.jacrev(pairwise, 0)(conf, charges, box))
    print(jax.jacrev(pairwise, 1)(conf, charges, box))

Results (in order: jacfwd w.r.t. conf, jacfwd w.r.t. charges, jacrev w.r.t. conf, jacrev w.r.t. charges):

[[-0.01737787  0.11672211  0.39640915]
 [-0.01465135 -0.13763744 -0.05796134]
 [ 0.20212802  0.2141647  -0.0337759 ]
 [-0.02587864  0.03783851 -0.05620446]
 [-0.14422016 -0.23108788 -0.24846745]]
[-0.12515384  1.015857    0.6813821   0.6616855   0.3078022 ]
[[-0.01737785  0.11672208  0.39640918]
 [-0.01465137 -0.13763744 -0.05796135]
 [ 0.20212804  0.21416475 -0.03377588]
 [-0.02587865  0.03783851 -0.05620446]
 [-0.14422016 -0.23108791 -0.24846748]]
[nan nan nan nan nan]
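
For reference, here is one way the jacrev nans can be avoided with the "double np.where" workaround discussed in the comments below. This is only a sketch, and it assumes the nans originate from the zero diagonal of dij reaching the division and the square root; pairwise_masked is a hypothetical rewrite, not part of the original report.

def pairwise_masked(conf, charges, box=None):
    # Same energy as pairwise(), but the zero diagonal is masked out *before*
    # the sqrt and the division, so reverse mode never differentiates
    # sqrt(x) or 1/x at x == 0.
    num_atoms = conf.shape[0]
    qij = np.expand_dims(charges, 0)*np.expand_dims(charges, 1)
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)

    keep = ~np.eye(num_atoms, dtype=bool)                    # off-diagonal pairs only
    d2 = np.sum(np.power(delta_r(ri, rj, box), 2), axis=-1)  # squared distances
    dij = np.sqrt(np.where(keep, d2, 1.0))                   # dummy 1.0 on the diagonal
    eij = np.where(keep, qij/dij, 0.0)                       # dij is never 0 here

    alphaEwald = 1.0
    eij_direct = np.where(dij > 2.0, np.zeros_like(eij), eij)
    eij_direct *= erfc(alphaEwald*eij_direct)
    return np.sum(eij_direct)/2

With this version, jax.jacrev(pairwise_masked, 1)(conf, charges, box) should agree with the jacfwd result instead of returning nans.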

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Reactions: 1
  • Comments: 5 (5 by maintainers)
Top GitHub Comments

2 reactions
mattjj commented, Apr 30, 2019

Funny that you ran across this now; if my guess at what’s going on is correct, we were just talking about this at the end of last week. Basically, np.where is hard to handle correctly when one side produces nans. JAX, Autograd, and TF (and possibly others) all have this bug. We think we know how to solve it in JAX (in a way we couldn’t with Autograd), but it’d take a surprising amount of work.

Here’s an even simpler repro:

In [1]: from jax import grad, jvp

In [2]: import jax.numpy as np

In [3]: grad(lambda x: np.where(True, x, np.log(x)))(0.)
Out[3]: array(nan, dtype=float32)

In [4]: jvp(lambda x: np.where(True, x, np.log(x)), (0.,), (1.,))
Out[4]: (array(0., dtype=float32), array(1., dtype=float32))

As you figured out, the workaround is to add more np.where calls. The more complete solution will take some time to describe.
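
As a concrete illustration (not from the original thread), here is a minimal sketch of that double-np.where pattern applied to the repro above: the inner np.where feeds np.log a harmless dummy value, so the discarded branch is never differentiated at 0, while the outer np.where still selects the same result.

from jax import grad
import jax.numpy as np

# inner where: np.log never sees 0, so its derivative is evaluated at 1. instead
# outer where: still selects x, exactly as before
grad(lambda x: np.where(True, x, np.log(np.where(True, 1., x))))(0.)   # -> 1.0, no nan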

For the time being, I’m glad this issue is open!

1 reaction
mattjj commented, Jul 29, 2019

Check out the discussion in #1052, especially this comment. The current executive summary is that this isn’t something easy to fix without introducing other problems; a short-term solution might be to provide an alternative to np.where.
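
To make that short-term idea concrete, such an alternative could look roughly like the sketch below. select_fn and its signature are invented here for illustration and are not an existing JAX API: it takes the branch functions rather than precomputed branch values, and applies the double-where trick internally so the untaken branch is only ever differentiated at a safe dummy input.

from jax import grad
import jax.numpy as np

def select_fn(pred, on_true, on_false, x, safe_x=1.0):
    # Hypothetical alternative to np.where: the untaken branch never sees x,
    # so its derivative is evaluated at safe_x and then masked away.
    x_true = np.where(pred, x, safe_x)    # on_true only differentiated where pred is True
    x_false = np.where(pred, safe_x, x)   # on_false only differentiated where pred is False
    return np.where(pred, on_true(x_true), on_false(x_false))

# e.g. the repro above, without the nan:
grad(lambda x: select_fn(True, lambda v: v, np.log, x))(0.)   # -> 1.0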


