Gradient Bug (Possibly to do with `eigh`).

We’ve been working on some code to simulate rigid bodies using JAX and came across a gradient bug (in the sense that `jax.test_util.check_grads` fails). Unfortunately, it was fairly difficult to minimize the repro, so it’s a bit long. I’m happy to iterate to narrow down the problem if that would be helpful.

Thanks very much for any help!

Here’s the code, run with the most recent version of JAX (v0.3.14) on both Colab and desktop:

```python
import jax.numpy as jnp
import jax.test_util as jtu
from jax import random
from jax import vmap, jit
from functools import partial

from collections import namedtuple

RigidBody = namedtuple('RigidBody', ['center', 'orientation'])

Array = jnp.ndarray

@partial(jnp.vectorize, signature='(q),(q)->(q)')
def _quaternion_multiply(lhs: Array, rhs: Array) -> Array:
  wl, xl, yl, zl = lhs
  wr, xr, yr, zr = rhs

  return jnp.array([
      -xl * xr - yl * yr - zl * zr + wl * wr,
      xl * wr + yl * zr - zl * yr + wl * xr,
      -xl * zr + yl * wr + zl * xr + wl * yr,
      xl * yr - yl * xr + zl * wr + wl * zr
  ])

@partial(jnp.vectorize, signature='(q)->(q)')
def _quaternion_conjugate(q: Array) -> Array:
  w, x, y, z = q
  return jnp.array([w, -x, -y, -z], dtype=q.dtype)

@partial(jnp.vectorize, signature='(q),(d)->(d)')
def _quaternion_apply(q: Array, v: Array) -> Array:
  # Rotates the 3-vector v by the unit quaternion q, computed as q * v * conj(q).
  if q.shape != (4,):
    raise ValueError(f'Expected a quaternion of shape (4,), found {q.shape}.')
  if v.shape != (3,):
    raise ValueError(f'Expected a vector of shape (3,), found {v.shape}.')

  v = jnp.concatenate([jnp.zeros((1,), v.dtype), v])
  q = _quaternion_multiply(q, _quaternion_multiply(v, _quaternion_conjugate(q)))
  return q[1:]

def transform(body: RigidBody, points: Array) -> Array:
  position, orientation = body
  return position[None, :] + _quaternion_apply(orientation, points)
# Map over a batch of bodies while sharing the same shape points.
transform = vmap(transform, (0, None))

def moment_of_inertia(points):
  ndim = points.shape[-1]
  I_sphere = 2 / 5
  @vmap
  def per_particle(point):
    diagonal = jnp.linalg.norm(point) ** 2 * jnp.eye(ndim)
    off_diagonal = point[:, None] * point[None, :]
    return ((diagonal - off_diagonal) + jnp.eye(ndim) * I_sphere)
  return jnp.sum(per_particle(points), axis=0)

def transform_to_diagonal_frame(shape_points):
  # Rotate the shape into the eigenbasis of its moment of inertia tensor.
  I = moment_of_inertia(shape_points)
  I_diag, U = jnp.linalg.eigh(I)
  shape_points = jnp.einsum('ni,ij->nj', shape_points, U)
  return shape_points

def _diagonal_mask(X: Array) -> Array:
  """Sets the diagonal of a matrix to zero."""
  if X.shape[0] != X.shape[1]:
    raise ValueError(
        'Diagonal mask can only mask square matrices. Found {}x{}.'.format(
            X.shape[0], X.shape[1]))
  if len(X.shape) > 3:
    raise ValueError(
        ('Diagonal mask can only mask rank-2 or rank-3 tensors. '
         'Found {}.'.format(len(X.shape))))
  N = X.shape[0]
  # NOTE(schsam): It seems potentially dangerous to set nans to 0 here. However,
  # masking nans also doesn't seem to work. So it also seems necessary. At the
  # very least we should do some @ErrorChecking.
  X = jnp.nan_to_num(X)
  mask = 1.0 - jnp.eye(N, dtype=X.dtype)
  if len(X.shape) == 3:
    mask = jnp.reshape(mask, (N, N, 1))
  return mask * X

def safe_mask(mask, fn, operand, placeholder=0):
  # Apply fn only where mask is True so masked-out entries cannot produce NaN gradients.
  masked = jnp.where(mask, operand, 0)
  return jnp.where(mask, fn(masked), placeholder)

def energy_fn(body, shape_points):
  R = transform(body, shape_points)
  R = jnp.reshape(R, (-1, R.shape[-1]))
  dr_2 = jnp.sum((R[:, None, :] - R[None, :, :]) ** 2, axis=-1)
  dr = safe_mask(dr_2 > 0, jnp.sqrt, dr_2)
  e = 0.5 * jnp.where(dr < 1.0, (1.0 - dr) ** 2, 0.0)
  return 0.5 * jnp.sum(_diagonal_mask(e))

def shape_energy_fn(shape_points):
  body = RigidBody(
    jnp.array([[0.0, 0.0, 0.0],
               [0.5, 0.25, 0.15]]),
    jnp.array([[1.0, 0.0, 0.0, 0.0],
               [1.0, 0.1, 0.0, 0.0]])
  )

  shape_points = transform_to_diagonal_frame(shape_points)
  return energy_fn(body, shape_points)

points = jnp.array([[-0.5, -0.5, -0.5],
                    [-0.5, -0.5,  0.5],
                    [ 0.5, -0.5, -0.5],
                    [ 0.5, -0.5,  0.5],
                    [-0.5,  0.5, -0.5],
                    [-0.5,  0.5,  0.5],
                    [ 0.5,  0.5, -0.5],
                    [ 0.5,  0.5,  0.5]])

# If you swap which of these two lines is commented out, the test passes. The
# only difference between them is whether the moment of inertia tensor (which
# is passed into `eigh`) has eigenvalues that are already ordered or out of order.
points = points * jnp.array([[1.0, 1.1, 1.2]])
# points = points * jnp.array([[1.2, 1.1, 1.0]])

jtu.check_grads(shape_energy_fn, (points,), 1)
```
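To make the ordering remark in the comment concrete, here is a minimal sketch that reimplements only the `moment_of_inertia` construction from the repro (using `jnp.dot(p, p)`, which equals `norm(p) ** 2`) and compares the two scalings. For this symmetric cube the tensor comes out diagonal, and since `jnp.linalg.eigh` returns eigenvalues in ascending order, the eigenvector matrix `U` should be close to a signed permutation: roughly the identity when the diagonal is already ascending, and the reversed identity when it is descending. The `results` dictionary is just an illustrative name.

```python
import jax.numpy as jnp
from jax import vmap

def moment_of_inertia(points):
    # Same construction as the repro: point-mass inertia plus a 2/5 term per particle.
    ndim = points.shape[-1]

    def per_particle(p):
        diagonal = jnp.dot(p, p) * jnp.eye(ndim)  # equal to norm(p) ** 2 * I
        off_diagonal = jnp.outer(p, p)
        return diagonal - off_diagonal + 0.4 * jnp.eye(ndim)

    return jnp.sum(vmap(per_particle)(points), axis=0)

# The eight corners of the unit cube used in the repro (row order does not matter here).
cube = jnp.array([[x, y, z] for x in (-0.5, 0.5)
                            for y in (-0.5, 0.5)
                            for z in (-0.5, 0.5)])

results = {}
for scale in ((1.0, 1.1, 1.2), (1.2, 1.1, 1.0)):
    I = moment_of_inertia(cube * jnp.array([scale]))
    eigenvalues, U = jnp.linalg.eigh(I)
    results[scale] = (jnp.diag(I), eigenvalues, U)
    print(scale, "diag(I) =", jnp.diag(I), "eigenvalues =", eigenvalues)
    print(jnp.round(jnp.abs(U), 3))
```

With the scaling that fails in the issue, the diagonal is descending, so `eigh` has to reorder the axes; with the scaling that passes, no reordering is needed.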

This fails with the following assertion error:

```
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.5010784
Max relative difference: 1.2183625
 x: array(0.089806, dtype=float32)
 y: array(-0.411272, dtype=float32)
```
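The "JVP tangent" failure compares an autodiff directional derivative against a numerical one. The sketch below shows roughly what that forward-mode comparison amounts to (`check_grads` also checks reverse mode); `jvp_vs_central_difference` and `smooth_energy` are hypothetical names for illustration, and float64 is enabled so that round-off does not swamp the finite difference.

```python
import jax
import jax.numpy as jnp

# Use float64 so the finite-difference estimate is not dominated by float32 round-off.
jax.config.update("jax_enable_x64", True)

def jvp_vs_central_difference(f, x, v, eps=1e-4):
    """Return (autodiff JVP, central finite-difference estimate) of f at x along v."""
    _, jvp_out = jax.jvp(f, (x,), (v,))
    fd_out = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)
    return jvp_out, fd_out

def smooth_energy(x):
    # A hypothetical smooth scalar function, used only for illustration.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 5)
v = jnp.ones_like(x)
ad, fd = jvp_vs_central_difference(smooth_energy, x, v)
print(ad, fd)
```

For a smooth function the two values agree closely; a gap as large as the one above (0.09 versus -0.41) points to a genuinely wrong derivative or a discontinuity somewhere in the function.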

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

1 reaction
jakevdp commented, May 27, 2022

Thanks @sschoenholz for the more concise version! This repros for me on a Colab CPU runtime, as well as on my own macbook.

0 reactions
rafael-fuente commented, May 29, 2022

Thanks! If the gradient issue is with `eigh`, can we reproduce it by calling `check_grads` on `transform_to_diagonal_frame`? That would significantly reduce the size of the repro.
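A minimal sketch of that reduced check, assuming the repro's `moment_of_inertia` and `transform_to_diagonal_frame` are reimplemented as below (`points @ U` is the same contraction as the repro's `einsum('ni,ij->nj', ...)`). `passes_check_grads` is a hypothetical helper that turns the `AssertionError` raised by `check_grads` into a boolean, so both scalings can be tried in one run; the outcome is printed rather than assumed.

```python
import jax.numpy as jnp
from jax import vmap
from jax.test_util import check_grads

def moment_of_inertia(points):
    # Same construction as the repro: point-mass inertia plus a 2/5 term per particle.
    ndim = points.shape[-1]

    def per_particle(p):
        return jnp.dot(p, p) * jnp.eye(ndim) - jnp.outer(p, p) + 0.4 * jnp.eye(ndim)

    return jnp.sum(vmap(per_particle)(points), axis=0)

def transform_to_diagonal_frame(points):
    # Rotate the points into the eigenbasis of their moment of inertia tensor.
    _, U = jnp.linalg.eigh(moment_of_inertia(points))
    return points @ U

def passes_check_grads(fn, *args):
    # Hypothetical helper: check_grads raises AssertionError when AD and numerics disagree.
    try:
        check_grads(fn, args, order=1)
    except AssertionError:
        return False
    return True

cube = jnp.array([[x, y, z] for x in (-0.5, 0.5)
                            for y in (-0.5, 0.5)
                            for z in (-0.5, 0.5)])

passed = {}
for scale in ((1.0, 1.1, 1.2), (1.2, 1.1, 1.0)):
    pts = cube * jnp.array([scale])
    passed[scale] = passes_check_grads(transform_to_diagonal_frame, pts)
    print(scale, "check_grads passed:", passed[scale])
```

If the first scaling fails here and the second passes, the problem reproduces without the quaternion and energy code.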

It’s likely that the bug in this example is related to `eigh`. I opened a new issue demonstrating that `eigh` gradients can be wrong: https://github.com/google/jax/issues/10877
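One way to narrow this further is to check the two outputs of `eigh` separately on diagonal matrices with distinct entries, in descending and in ascending order (the same ordering distinction as in the repro). This is a sketch under stated assumptions: squaring the eigenvectors elementwise is used here only to remove the arbitrary sign of each column, and `eigenvalues`, `squared_eigenvectors`, and `passes` are hypothetical names. Finite differences through `eigh` in float32 are noisy, so a failure here is a hint rather than proof.

```python
import jax.numpy as jnp
from jax.test_util import check_grads

def eigenvalues(A):
    return jnp.linalg.eigh(A)[0]

def squared_eigenvectors(A):
    # Squaring elementwise removes the arbitrary sign of each eigenvector column.
    return jnp.linalg.eigh(A)[1] ** 2

def passes(fn, A):
    # check_grads raises AssertionError when autodiff and finite differences disagree.
    try:
        check_grads(fn, (A,), order=1)
    except AssertionError:
        return False
    return True

descending = jnp.diag(jnp.array([8.5, 8.08, 7.62]))
ascending = jnp.diag(jnp.array([7.62, 8.08, 8.5]))

report = {}
for name, A in (("descending", descending), ("ascending", ascending)):
    report[name] = {fn.__name__: passes(fn, A) for fn in (eigenvalues, squared_eigenvectors)}
    print(name, report[name])
```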
