Excessive loss of precision
I’m seeing a loss of precision far beyond what is typically acceptable for 64-bit calculations:
from jax.config import config; config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
import numpy as anp
from jax.test_util import check_grads
import unittest
class PeriodicTorsion():

    def __init__(self,
        torsion_idxs,
        param_idxs):
        """
        This implements a periodic torsional potential expanded out into three terms:

        V(a) = k0*(1+cos(1 * a - t0)) + k1*(1+cos(2 * a - t1)) + k2*(1+cos(3 * a - t2))

        Parameters
        ----------
        torsion_idxs: [num_torsions, 4] np.array
            each element (a, b, c, d) is a torsion of four atoms, defined as
            the dihedral angle formed by the three bond vectors a-b, b-c, c-d

        param_idxs: [num_torsions, 3] np.array
            each element (k, phase, periodicity) maps into params for the force
            constant, phase, and periodicity of one cosine term

        """
        self.torsion_idxs = torsion_idxs
        self.param_idxs = param_idxs
    @staticmethod
    def get_signed_angle(ci, cj, ck, cl):
        """
        The torsion angle between two planes should be periodic but not
        necessarily symmetric. Unlike the OpenMM energy function, we use a
        mathematically identical but numerically stable arctan2 formulation
        to avoid a singularity when the angle is zero. This differs from the
        CHARMM convention of symmetric torsions.
        """
        # Taken from the arctan2 formulation on Wikipedia:
        # https://en.wikipedia.org/wiki/Dihedral_angle
        rij = ci - cj
        rkj = ck - cj
        rkl = ck - cl

        # normals of the two planes
        n1 = np.cross(rij, rkj)
        n2 = np.cross(rkj, rkl)

        lhs = np.linalg.norm(n1, axis=-1)
        rhs = np.linalg.norm(n2, axis=-1)
        bot = lhs * rhs  # note: unused below

        y = np.sum(np.multiply(np.cross(n1, n2), rkj/np.linalg.norm(rkj, axis=-1, keepdims=True)), axis=-1)
        x = np.sum(np.multiply(n1, n2), -1)

        return np.arctan2(y, x)
    def angles(self, conf):
        ci = conf[self.torsion_idxs[:, 0]]
        cj = conf[self.torsion_idxs[:, 1]]
        ck = conf[self.torsion_idxs[:, 2]]
        cl = conf[self.torsion_idxs[:, 3]]
        angle = self.get_signed_angle(ci, cj, ck, cl)
        return angle
    def energy(self, conf, params):
        """
        Compute the torsional energy.
        """
        ci = conf[self.torsion_idxs[:, 0]]
        cj = conf[self.torsion_idxs[:, 1]]
        ck = conf[self.torsion_idxs[:, 2]]
        cl = conf[self.torsion_idxs[:, 3]]

        ks = params[self.param_idxs[:, 0]]
        phase = params[self.param_idxs[:, 1]]
        period = params[self.param_idxs[:, 2]]
        angle = self.get_signed_angle(ci, cj, ck, cl)
        nrg = ks*(1+np.cos(period * angle - phase))
        return np.sum(nrg, axis=-1)
class TestPeriodicTorsion(unittest.TestCase):

    def setUp(self):
        self.conformers = anp.array([
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.561317027011325 , 0.2066950040043141, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.9399773448903637,-0.6888774474110431, 0.2104211949995816]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113252, 0.2066950040043142, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 1.283345455745044 ,-0.0356257425880843,-0.2573923896494185]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.561317027011325 , 0.2066950040043142, 0.3670430960815992],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 1.263820400176392 , 0.7964992122869241, 0.0084568741589791]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113252, 0.2066950040043142, 0.3670430960815992],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.8993534242298198, 1.042445571242743 , 0.7635483993060286]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113255, 0.2066950040043142, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.5250337847650304, 0.476091386095139 , 1.3136545198545133]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113255, 0.2066950040043141, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.485009232042489 ,-0.3818599172073237, 1.1530102055165103]],
        ], dtype=anp.float64)

        self.nan_conformers = anp.array([
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113252, 0.2066950040043142, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 1.2278668040866427, 0.8805184219394547, 0.099391329616366 ]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.561317027011325 , 0.206695004004314 , 0.3670430960815994],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.5494071252089705,-0.5626592973923106, 0.9817919758125693]],
        ], dtype=anp.float64)
    def test_gpu_torsions(self):
        """
        Test agreement of torsions with OpenMM's implementation of torsion terms.
        """
        torsion_idxs = anp.array([
            [0, 1, 2, 3],
            # [0, 1, 2, 3],
            # [0, 1, 2, 3],
        ], dtype=anp.int32)

        params_np = anp.array([
            2.3, # k0
            5.4, # k1
            9.0, # k2
            0.0, # t0
            3.0, # t1
            5.8, # t2
            1.0, # n0
            2.0, # n1
            3.0  # n2
        ], dtype=anp.float64)

        param_idxs = anp.array([
            # [0, 3, 6], # works
            [1, 4, 7], # fails
            # [2, 5, 8] # fails
        ], dtype=anp.int32)

        ref_nrg = PeriodicTorsion(
            param_idxs=param_idxs,
            torsion_idxs=torsion_idxs
        )

        for conf_idx, conf in enumerate(self.conformers):
            nrg = ref_nrg.energy(conf, params_np)
            angles = ref_nrg.angles(conf)
            print(nrg, angles)
            # first-order derivatives pass
            check_grads(ref_nrg.angles, (conf,), order=1) # passes
            check_grads(ref_nrg.energy, (conf, params_np), order=1) # passes
            # second-order derivatives fail miserably; both of these should pass in 64-bit
            check_grads(ref_nrg.angles, (conf,), order=2) # fails
            check_grads(ref_nrg.energy, (conf, params_np), order=2) # fails

if __name__ == "__main__":
    unittest.main()
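Before diving into the failures, one knob worth sweeping is the finite-difference step that check_grads uses for its numerical probe. The loop below is a minimal sketch, not part of the original repro: the eps values are illustrative, and it assumes ref_nrg, conf, and params_np from the script above are in scope.

    # If the mismatch shrinks or grows systematically with eps, the
    # finite-difference probe (rather than the VJP itself) may be
    # dominating the reported error.
    for eps in (1e-3, 1e-4, 1e-5):
        try:
            check_grads(ref_nrg.energy, (conf, params_np), order=2, eps=eps)
            print("eps=%g: passed" % eps)
        except AssertionError:
            print("eps=%g: failed" % eps)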
The failures look like:
======================================================================
FAIL: test_gpu_torsions (__main__.TestPeriodicTorsion)
----------------------------------------------------------------------
Traceback (most recent call last):
File "timemachine/jax_functionals/loss_of_precision_repro.py", line 175, in test_gpu_torsions
check_grads(ref_nrg.energy, (conf, params_np), order=2) # fails
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 153, in check_grads
check_grads(f_vjp, args, order - 1, atol=atol, rtol=rtol, eps=eps)
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 160, in check_grads
check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 144, in check_vjp
check_close(ip, ip_expected, atol=atol, rtol=rtol)
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 89, in check_close
assert tree_all(tree_multimap(close, xs, ys)), '\n{} != \n{}'.format(xs, ys)
AssertionError:
55.438952303284566 !=
55.43906353281106
======================================================================
FAIL: test_gpu_torsions (__main__.TestPeriodicTorsion)
----------------------------------------------------------------------
Traceback (most recent call last):
File "timemachine/jax_functionals/loss_of_precision_repro.py", line 174, in test_gpu_torsions
check_grads(ref_nrg.angles, (conf,), order=2) # fails
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 153, in check_grads
check_grads(f_vjp, args, order - 1, atol=atol, rtol=rtol, eps=eps)
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 160, in check_grads
check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 144, in check_vjp
check_close(ip, ip_expected, atol=atol, rtol=rtol)
File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 89, in check_close
assert tree_all(tree_multimap(close, xs, ys)), '\n{} != \n{}'.format(xs, ys)
AssertionError:
0.381797537133656 !=
0.381794076418857
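Plugging the two reported mismatches into a quick relative-error check (values copied verbatim from the tracebacks above) puts both discrepancies at roughly 1e-6 to 1e-5 in relative terms, far above float64 machine epsilon (~2.2e-16):

    # Relative error of the two reported mismatches.
    pairs = [(55.438952303284566, 55.43906353281106),
             (0.381797537133656, 0.381794076418857)]
    for got, expected in pairs:
        print(abs(got - expected) / abs(expected))  # ~2.0e-06 and ~9.1e-06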
Thanks - I think we’re in agreement here.
@proteneer Actually, I observe that if I turn on double precision for JAX, the results once again agree with those from finite differences. So it does seem like, in my case above, the error was coming from JAX running in single precision. I have to say I am a bit reluctant to admit that I really need double precision in JAX… (I was hoping I could save some computational resources by using single precision.)
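For reference, enabling double precision has to happen before any JAX arrays are created, exactly as the repro script does on its first line. The second half of the sketch below is a hypothetical mixed-precision compromise, not a JAX API: with x64 enabled, bulk data can still be kept in float32 and only the precision-sensitive reduction promoted.

    from jax.config import config
    config.update("jax_enable_x64", True)  # must run before any arrays are created

    import jax.numpy as np

    # Hypothetical helper: promote a float32 array to float64 for the
    # reduction, then cast the scalar result back down.
    def sum_in_f64(x):
        return np.sum(x.astype(np.float64)).astype(np.float32)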