jacfwd gives different results from jacrev
When working with an electrostatic pairwise potential, I’m seeing NaNs in jacrev but not in jacfwd:
import jax
import jax.numpy as np
from jax.scipy.special import erf, erfc

def delta_r(ri, rj, box=None):
    diff = ri - rj # this can be either N,N,3 or B,3
    if box is not None:
        diff -= box[2]*np.floor(np.expand_dims(diff[...,2], axis=-1)/box[2][2]+0.5)
        diff -= box[1]*np.floor(np.expand_dims(diff[...,1], axis=-1)/box[1][1]+0.5)
        diff -= box[0]*np.floor(np.expand_dims(diff[...,0], axis=-1)/box[0][0]+0.5)
    return diff

def distance(ri, rj, box=None):
    dxdydz = np.power(delta_r(ri, rj, box), 2)
    # np.linalg.norm nans but this doesn't
    dij = np.sqrt(np.sum(dxdydz, axis=-1))
    return dij

def pairwise(conf, charges, box=None):
    num_atoms = conf.shape[0]
    qi = np.expand_dims(charges, 0) # (1, N)
    qj = np.expand_dims(charges, 1) # (N, 1)
    qij = np.multiply(qi, qj)
    ri = np.expand_dims(conf, 0)
    rj = np.expand_dims(conf, 1)
    dij = distance(ri, rj, box)
    eij = qij/dij
    eij = np.where(np.eye(num_atoms), np.zeros_like(eij), eij) # zero out diagonals
    # print(dij)
    alphaEwald = 1.0
    eij_direct = np.where(dij > 2.0, np.zeros_like(eij), eij)
    eij_direct *= erfc(alphaEwald*eij_direct)
    eij_direct = np.sum(eij_direct)/2
    return np.sum(eij_direct)

if __name__ == "__main__":
    charges = np.array([1.3, 0.3, 0.3, 0.3, 0.3], dtype=np.float64)
    conf = np.array([
        [ 0.0637, 0.0126, 0.2203],
        [ 1.0573, -0.2011, 1.2864],
        [ 2.3928, 1.2209, -0.2230],
        [-0.6891, 1.6983, 0.0780],
        [-0.6312, -1.6261, -0.2601]
    ], dtype=np.float64)
    box = np.array([
        [2.0, 0.0, 0.0],
        [0.6, 1.6, 0.0],
        [0.4, 0.7, 1.1]
    ], dtype=np.float64)
    print(jax.jacfwd(pairwise, 0)(conf, charges, box))
    print(jax.jacfwd(pairwise, 1)(conf, charges, box))
    print(jax.jacrev(pairwise, 0)(conf, charges, box))
    print(jax.jacrev(pairwise, 1)(conf, charges, box))
Results:
[[-0.01737787 0.11672211 0.39640915]
[-0.01465135 -0.13763744 -0.05796134]
[ 0.20212802 0.2141647 -0.0337759 ]
[-0.02587864 0.03783851 -0.05620446]
[-0.14422016 -0.23108788 -0.24846745]]
[-0.12515384 1.015857 0.6813821 0.6616855 0.3078022 ]
[[-0.01737785 0.11672208 0.39640918]
[-0.01465137 -0.13763744 -0.05796135]
[ 0.20212804 0.21416475 -0.03377588]
[-0.02587865 0.03783851 -0.05620446]
[-0.14422016 -0.23108791 -0.24846748]]
[nan nan nan nan nan]
Issue Analytics
- State:
- Created 4 years ago
- Reactions:1
- Comments:5 (5 by maintainers)
Funny that you run across this now; if my guess at what’s going on is correct, we were just talking about this at the end of last week. Basically, np.where is hard to handle correctly when one side produces NaNs. JAX, Autograd, and TF (and possibly others) all have this bug. We think we know how to solve it in JAX (in a way we couldn’t with Autograd), but it’d take a surprising amount of work.
Here’s an even simpler repro:
As you figured out, the workaround is to add more np.where calls. The more complete solution will take some time to describe.
For the time being, I’m glad this issue is open!
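One way to read "add more np.where calls" is the double-where pattern: an inner where replaces the problematic input with a harmless value before the risky operation, and an outer where restores the masked output. The sketch below applies this to a simplified pairwise Coulomb energy. It is an illustration under stated assumptions, not the fix used in the original code: the periodic box and the erfc damping are left out, and the name `pairwise_energy` and the test data are made up.

```python
import jax
import jax.numpy as jnp

def pairwise_energy(conf, charges):
    # Simplified stand-in for the issue's function: no box, no erfc term.
    n = conf.shape[0]
    diag = jnp.eye(n, dtype=bool)
    diff = conf[None, :, :] - conf[:, None, :]
    d2 = jnp.sum(diff ** 2, axis=-1)
    # Inner where: replace the zero self-distances BEFORE sqrt and division,
    # so no inf or nan appears in either the values or the derivatives.
    safe_d2 = jnp.where(diag, 1.0, d2)
    dij = jnp.sqrt(safe_d2)
    qij = charges[None, :] * charges[:, None]
    # Outer where: zero out the diagonal (self-interaction) terms.
    eij = jnp.where(diag, 0.0, qij / dij)
    return jnp.sum(eij) / 2

# Made-up test data: three particles with pair distances 1, 2 and sqrt(5).
conf = jnp.array([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 2.0, 0.0]])
charges = jnp.array([1.0, 0.5, -0.5])

g_fwd = jax.jacfwd(pairwise_energy, 1)(conf, charges)
g_rev = jax.jacrev(pairwise_energy, 1)(conf, charges)
print(g_fwd)
print(g_rev)
```

Because the masked entries never pass through a division by zero, both modes see only finite local derivatives and should agree up to floating-point rounding.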
Check out the discussion in #1052, especially this comment. The current executive summary is that this isn’t something easy to fix without introducing other problems; a short-term solution might be to provide an alternative to np.where.