Excessive loss of precision

I’m seeing an excessive loss of precision far beyond what is typically acceptable for 64-bit calculations:

from jax.config import config; config.update("jax_enable_x64", True)

import jax
import jax.numpy as np
import numpy as anp
from jax.test_util import check_grads
import unittest

class PeriodicTorsion():

    def __init__(self,
        torsion_idxs,
        param_idxs):
        """
        This implements a periodic torsional potential expanded out into three terms:

        V(a) = k0*(1+cos(1 * a - t0)) + k1*(1+cos(2 * a - t1)) + k2*(1+cos(3 * a - t2))

        Parameters:
        -----------
        torsion_idxs: [num_torsions, 4] np.array
            each element (a, b, c, d) is a torsion of four atoms, defined
            as the angle between the planes spanned by the three bond vectors a-b, b-c, c-d.

        param_idxs: [num_torsions, 3] np.array
            each element (k, phase, periodicity) maps into params for angle constants and ideal angles

        """
        self.torsion_idxs = torsion_idxs
        self.param_idxs = param_idxs

    @staticmethod
    def get_signed_angle(ci, cj, ck, cl):
        """
        The torsion angle between two planes should be periodic but not
        necessarily symmetric. We use an identical but numerically stable arctan2
        implementation as opposed to the OpenMM energy function to avoid a
        singularity when the angle is zero. This is different from the CHARMM convention
        of symmetric torsions.
        """

        # Taken from the wikipedia arctan2 implementation:
        # https://en.wikipedia.org/wiki/Dihedral_angle

        rij = ci - cj
        rkj = ck - cj
        rkl = ck - cl

        n1 = np.cross(rij, rkj)
        n2 = np.cross(rkj, rkl)

        lhs = np.linalg.norm(n1, axis=-1)
        rhs = np.linalg.norm(n2, axis=-1)
        bot = lhs * rhs

        y = np.sum(np.multiply(np.cross(n1, n2), rkj/np.linalg.norm(rkj, axis=-1, keepdims=True)), axis=-1)
        x = np.sum(np.multiply(n1, n2), -1)

        return np.arctan2(y, x)

    def angles(self, conf):
        ci = conf[self.torsion_idxs[:, 0]]
        cj = conf[self.torsion_idxs[:, 1]]
        ck = conf[self.torsion_idxs[:, 2]]
        cl = conf[self.torsion_idxs[:, 3]]

        angle = self.get_signed_angle(ci, cj, ck, cl)
        return angle

    def energy(self, conf, params):
        """
        Compute the torsional energy.
        """
        ci = conf[self.torsion_idxs[:, 0]]
        cj = conf[self.torsion_idxs[:, 1]]
        ck = conf[self.torsion_idxs[:, 2]]
        cl = conf[self.torsion_idxs[:, 3]]

        ks = params[self.param_idxs[:, 0]]
        phase = params[self.param_idxs[:, 1]]
        period = params[self.param_idxs[:, 2]]
        angle = self.get_signed_angle(ci, cj, ck, cl)
        nrg = ks*(1+np.cos(period * angle - phase))
        return np.sum(nrg, axis=-1)


class TestPeriodicTorsion(unittest.TestCase):

    def setUp(self):
        self.conformers = anp.array([
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.561317027011325 , 0.2066950040043141, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.9399773448903637,-0.6888774474110431, 0.2104211949995816]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113252, 0.2066950040043142, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 1.283345455745044 ,-0.0356257425880843,-0.2573923896494185]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.561317027011325 , 0.2066950040043142, 0.3670430960815992],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 1.263820400176392 , 0.7964992122869241, 0.0084568741589791]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113252, 0.2066950040043142, 0.3670430960815992],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.8993534242298198, 1.042445571242743 , 0.7635483993060286]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113255, 0.2066950040043142, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.5250337847650304, 0.476091386095139 , 1.3136545198545133]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113255, 0.2066950040043141, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.485009232042489 ,-0.3818599172073237, 1.1530102055165103]],
            ], dtype=anp.float64)

        self.nan_conformers = anp.array([
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.5613170270113252, 0.2066950040043142, 0.3670430960815993],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 1.2278668040866427, 0.8805184219394547, 0.099391329616366 ]],
            [[-0.6000563454193615, 0.376172954382274 ,-0.2487295756125901],
             [ 0.561317027011325 , 0.206695004004314 , 0.3670430960815994],
             [-1.187055522272264 ,-0.3415864358441354, 0.0871382207830652],
             [ 0.5494071252089705,-0.5626592973923106, 0.9817919758125693]],
            ], dtype=anp.float64)


    def test_gpu_torsions(self):
        """
        Test agreement of torsions with OpenMM's implementation of torsion terms.
        """
        torsion_idxs = anp.array([
            [0, 1, 2, 3],
            # [0, 1, 2, 3],
            # [0, 1, 2, 3],
        ], dtype=anp.int32)

        params_np = anp.array([
            2.3, # k0
            5.4, # k1
            9.0, # k2
            0.0, # t0
            3.0, # t1
            5.8, # t2
            1.0, # n0
            2.0, # n1
            3.0  # n2
        ], dtype=anp.float64)

        param_idxs = anp.array([
            # [0, 3, 6], # works
            [1, 4, 7], # fails
            # [2, 5, 8] # fails
        ], dtype=anp.int32)

        ref_nrg = PeriodicTorsion(
            param_idxs=param_idxs,
            torsion_idxs=torsion_idxs
        )

        for conf_idx, conf in enumerate(self.conformers):
            nrg = ref_nrg.energy(conf, params_np)
            angles = ref_nrg.angles(conf)

            print(nrg, angles)

            # first order derivatives pass
            check_grads(ref_nrg.angles, (conf,), order=1) # passes
            check_grads(ref_nrg.energy, (conf, params_np), order=1) # passes


            # second order derivatives fail miserably, both of these should pass in 64bit
            check_grads(ref_nrg.angles, (conf,), order=2) # fails
            check_grads(ref_nrg.energy, (conf, params_np), order=2) # fails




if __name__ == "__main__":
    unittest.main()

The failures look like:

======================================================================
FAIL: test_gpu_torsions (__main__.TestPeriodicTorsion)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "timemachine/jax_functionals/loss_of_precision_repro.py", line 175, in test_gpu_torsions
    check_grads(ref_nrg.energy, (conf, params_np), order=2) # fails
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 153, in check_grads
    check_grads(f_vjp, args, order - 1, atol=atol, rtol=rtol, eps=eps)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 160, in check_grads
    check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 144, in check_vjp
    check_close(ip, ip_expected, atol=atol, rtol=rtol)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 89, in check_close
    assert tree_all(tree_multimap(close, xs, ys)), '\n{} != \n{}'.format(xs, ys)
AssertionError: 
55.438952303284566 != 
55.43906353281106

======================================================================
FAIL: test_gpu_torsions (__main__.TestPeriodicTorsion)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "timemachine/jax_functionals/loss_of_precision_repro.py", line 174, in test_gpu_torsions
    check_grads(ref_nrg.angles, (conf,), order=2) # fails
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 153, in check_grads
    check_grads(f_vjp, args, order - 1, atol=atol, rtol=rtol, eps=eps)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 160, in check_grads
    check_vjp(f, partial(api.vjp, f), args, atol, rtol, eps)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 144, in check_vjp
    check_close(ip, ip_expected, atol=atol, rtol=rtol)
  File "/Users/hessian/venv/lib/python3.6/site-packages/jax/test_util.py", line 89, in check_close
    assert tree_all(tree_multimap(close, xs, ys)), '\n{} != \n{}'.format(xs, ys)
AssertionError: 
0.381797537133656 != 
0.381794076418857
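
To put the two mismatches above in perspective, it can help to express them as relative gaps and compare them with the error a finite-difference reference can carry on its own. The following is a minimal sketch using only the Python standard library, not JAX and not `check_grads` itself. The value pairs are copied from the assertion messages; the toy term `V(a) = k*(1 + cos(n*a - t))` reuses `k1`, `t1`, `n1` from `params_np`; the helper names (`rel_gap`, `central_diff`), the test angle `a0`, and the list of step sizes are assumptions made for illustration.

```python
import math

# Value pairs copied from the two AssertionError messages above.
reported = {
    "energy, order=2": (55.438952303284566, 55.43906353281106),
    "angles, order=2": (0.381797537133656, 0.381794076418857),
}

def rel_gap(a, b):
    """Relative difference between two scalars."""
    return abs(a - b) / max(abs(a), abs(b))

gaps = {name: rel_gap(a, b) for name, (a, b) in reported.items()}
for name, gap in gaps.items():
    print(f"{name}: relative gap = {gap:.2e}")

# Toy stand-in for a single torsion term, V(a) = k * (1 + cos(n*a - t)),
# with k1, t1, n1 taken from params_np in the report.
K, T, N = 5.4, 3.0, 2.0

def dV(a):
    """Analytic first derivative of V with respect to the angle a."""
    return -K * N * math.sin(N * a - T)

def d2V(a):
    """Analytic second derivative of V with respect to the angle a."""
    return -K * N * N * math.cos(N * a - T)

def central_diff(f, x, eps):
    """Central difference approximation of f'(x) with step eps."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

a0 = 0.3  # arbitrary test angle (an assumption, not taken from the report)
errors = {}
for eps in (1e-2, 1e-4, 1e-6, 1e-8, 1e-10):
    approx = central_diff(dV, a0, eps)
    errors[eps] = rel_gap(approx, d2V(a0))
    print(f"eps={eps:.0e}: relative error = {errors[eps]:.2e}")
```

The sweep shows the usual trade-off: truncation error dominates for large steps and roundoff grows as the step shrinks, so the accuracy of a difference quotient depends strongly on the step size. This does not establish where the reported mismatch comes from; it only illustrates that the numerical reference has its own error budget to rule out before attributing the gap to the derivatives under test.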

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 22 (20 by maintainers)

Top GitHub Comments

proteneer commented, Apr 26, 2019

Thanks - I think we’re in agreement here.

gnool commented, Nov 2, 2020

@proteneer Actually, I observe that if I turn on double precision for JAX, the results agree with those from finite differences again. So it does seem that in my case above the error was coming from JAX due to single precision? I have to say I am a bit reluctant to admit that I really need double precision in JAX… (I was hoping I could save some computational resources by using single precision.)
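
As a rough, library-free illustration of how many digits single precision can cost in derivative-related arithmetic, the sketch below evaluates the same central difference once in double precision and once with every intermediate rounded to binary32. This is not JAX and not the computation from the comment above: the rounding helper `f32`, the test function `sin`, the point `x0`, and the step `eps` are all assumptions chosen for illustration.

```python
import math
import struct

def f32(x):
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def central_diff64(f, x, eps):
    """Central difference evaluated entirely in double precision."""
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

def central_diff32(f, x, eps):
    """Same quotient, with every intermediate rounded to single precision."""
    x, eps = f32(x), f32(eps)
    hi = f32(f(f32(x + eps)))
    lo = f32(f(f32(x - eps)))
    return f32(f32(hi - lo) / f32(2.0 * eps))

x0, eps = 1.0, 1e-4
exact = math.cos(x0)  # d/dx sin(x) = cos(x)

err64 = abs(central_diff64(math.sin, x0, eps) - exact) / abs(exact)
err32 = abs(central_diff32(math.sin, x0, eps) - exact) / abs(exact)
print(f"float64 relative error: {err64:.2e}")
print(f"float32 relative error: {err32:.2e}")
```

With about seven significant digits available, the subtraction of two nearly equal function values leaves few reliable digits in the single-precision quotient, which is one reason comparisons against a double-precision reference can look worse than expected.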
