Gradient Bug (Possibly related to `eigh`)
We’ve been working on some code to simulate rigid bodies using JAX and came across a gradient bug (in the sense that jax.test_util.check_grads fails). Unfortunately, it was fairly difficult to minimize the repro, so it’s a bit long. I’m happy to iterate to try to narrow down the problem if that would be helpful.
Thanks very much for any help!
Here’s the code, run on the most recent version of JAX (v0.3.14), both on Colab and desktop:
import jax.numpy as jnp
import jax.test_util as jtu
from jax import random
from jax import vmap, jit

from functools import partial
from collections import namedtuple

RigidBody = namedtuple('RigidBody', ['center', 'orientation'])
Array = jnp.ndarray

@partial(jnp.vectorize, signature='(q),(q)->(q)')
def _quaternion_multiply(lhs: Array, rhs: Array) -> Array:
  wl, xl, yl, zl = lhs
  wr, xr, yr, zr = rhs
  return jnp.array([
      -xl * xr - yl * yr - zl * zr + wl * wr,
      xl * wr + yl * zr - zl * yr + wl * xr,
      -xl * zr + yl * wr + zl * xr + wl * yr,
      xl * yr - yl * xr + zl * wr + wl * zr
  ])

@partial(jnp.vectorize, signature='(q)->(q)')
def _quaternion_conjugate(q: Array) -> Array:
  w, x, y, z = q
  return jnp.array([w, -x, -y, -z], dtype=q.dtype)

@partial(jnp.vectorize, signature='(q),(d)->(d)')
def _quaternion_apply(q: Array, v: Array) -> Array:
  if q.shape != (4,):
    raise ValueError('Expected quaternion of shape (4,); got {}.'.format(q.shape))
  if v.shape != (3,):
    raise ValueError('Expected vector of shape (3,); got {}.'.format(v.shape))
  v = jnp.concatenate([jnp.zeros((1,), v.dtype), v])
  q = _quaternion_multiply(q, _quaternion_multiply(v, _quaternion_conjugate(q)))
  return q[1:]
def transform(body: RigidBody, points: Array) -> Array:
  position, orientation = body
  return position[None, :] + _quaternion_apply(orientation, points)
transform = vmap(transform, (0, None))

def moment_of_inertia(points):
  ndim = points.shape[-1]
  I_sphere = 2 / 5

  @vmap
  def per_particle(point):
    diagonal = jnp.linalg.norm(point) ** 2 * jnp.eye(point.shape[-1])
    off_diagonal = point[:, None] * point[None, :]
    return (diagonal - off_diagonal) + jnp.eye(ndim) * I_sphere

  return jnp.sum(per_particle(points), axis=0)
def transform_to_diagonal_frame(shape_points):
  I = moment_of_inertia(shape_points)
  I_diag, U = jnp.linalg.eigh(I)
  shape_points = jnp.einsum('ni,ij->nj', shape_points, U)
  return shape_points

def _diagonal_mask(X: Array) -> Array:
  """Sets the diagonal of a matrix to zero."""
  if X.shape[0] != X.shape[1]:
    raise ValueError(
        'Diagonal mask can only mask square matrices. Found {}x{}.'.format(
            X.shape[0], X.shape[1]))
  if len(X.shape) > 3:
    raise ValueError(
        'Diagonal mask can only mask rank-2 or rank-3 tensors. '
        'Found {}.'.format(len(X.shape)))
  N = X.shape[0]
  # NOTE(schsam): It seems potentially dangerous to set nans to 0 here.
  # However, masking nans also doesn't seem to work, so it also seems
  # necessary. At the very least we should do some @ErrorChecking.
  X = jnp.nan_to_num(X)
  mask = 1.0 - jnp.eye(N, dtype=X.dtype)
  if len(X.shape) == 3:
    mask = jnp.reshape(mask, (N, N, 1))
  return mask * X

def safe_mask(mask, fn, operand, placeholder=0):
  masked = jnp.where(mask, operand, 0)
  return jnp.where(mask, fn(masked), placeholder)
def energy_fn(body, shape_points):
  R = transform(body, shape_points)
  R = jnp.reshape(R, (-1, R.shape[-1]))
  dr_2 = jnp.sum((R[:, None, :] - R[None, :, :]) ** 2, axis=-1)
  dr = safe_mask(dr_2 > 0, jnp.sqrt, dr_2)
  e = 0.5 * jnp.where(dr < 1.0, (1.0 - dr) ** 2, 0.0)
  return 0.5 * jnp.sum(_diagonal_mask(e))

def shape_energy_fn(shape_points):
  body = RigidBody(
      jnp.array([[0.0, 0.0, 0.0],
                 [0.5, 0.25, 0.15]]),
      jnp.array([[1.0, 0.0, 0.0, 0.0],
                 [1.0, 0.1, 0.0, 0.0]])
  )
  shape_points = transform_to_diagonal_frame(shape_points)
  return energy_fn(body, shape_points)

points = jnp.array([[-0.5, -0.5, -0.5],
                    [-0.5, -0.5,  0.5],
                    [ 0.5, -0.5, -0.5],
                    [ 0.5, -0.5,  0.5],
                    [-0.5,  0.5, -0.5],
                    [-0.5,  0.5,  0.5],
                    [ 0.5,  0.5, -0.5],
                    [ 0.5,  0.5,  0.5]])
# If these two lines are swapped, the test passes. The only difference
# between the two is whether the eigenvalues of the moment of inertia
# tensor (which is passed into `eigh`) are already in ascending order or
# out of order.
points = points * jnp.array([[1.0, 1.1, 1.2]])
# points = points * jnp.array([[1.2, 1.1, 1.0]])
jtu.check_grads(shape_energy_fn, (points,), 1)
This fails with the following error:
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.5010784
Max relative difference: 1.2183625
x: array(0.089806, dtype=float32)
y: array(-0.411272, dtype=float32)
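To confirm the ordering claim in the comment above the two scaling lines, one can compare the diagonal of the inertia tensor with the eigenvalues `eigh` returns (which are always sorted in ascending order). A small diagnostic sketch of my own, not part of the original repro; `cube` is just the unscaled vertices, re-declared so the snippet stands alone:

cube = jnp.array([[sx, sy, sz]
                  for sx in (-0.5, 0.5)
                  for sy in (-0.5, 0.5)
                  for sz in (-0.5, 0.5)])
for scale in ([1.0, 1.1, 1.2], [1.2, 1.1, 1.0]):
  I = moment_of_inertia(cube * jnp.array([scale]))
  # eigh sorts ascending; for [1.0, 1.1, 1.2] the diagonal comes out
  # descending, so U has to permute the axes. For [1.2, 1.1, 1.0] the
  # diagonal is already ascending and U is (numerically) the identity.
  print(scale, jnp.diag(I), jnp.linalg.eigh(I)[0])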
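One way to see which of the two tangents is trustworthy is to compare `jax.grad` against a central finite difference along a random direction. A rough float32 sketch of my own, assuming the repro’s definitions are in scope:

import jax

key = random.PRNGKey(0)
v = random.normal(key, points.shape)
eps = 1e-3  # crude step size for float32
ad = jnp.vdot(jax.grad(shape_energy_fn)(points), v)
fd = (shape_energy_fn(points + eps * v)
      - shape_energy_fn(points - eps * v)) / (2 * eps)
# If the gradients are correct, ad and fd should agree to a few decimal places.
print(ad, fd)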
Thanks @sschoenholz for the more concise version! This repros for me on a Colab CPU runtime, as well as on my own MacBook.
It’s likely that the bug in this example is related to eigh. I’ve opened a new issue showing that eigh gradients can be wrong: https://github.com/google/jax/issues/10877
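For anyone trying to narrow this down further, a sketch along these lines (mine, not the repro from that issue) isolates the `eigh` step from the rest of the energy function. The fourth power makes the output invariant to the sign ambiguity in `eigh`’s eigenvectors, so a `check_grads` failure here would implicate `eigh` itself; whether this particular check fails will depend on the input and the JAX version:

def eigh_only(pts):
  _, U = jnp.linalg.eigh(moment_of_inertia(pts))
  P = pts @ U
  # The fourth power is invariant to sign flips of the columns of U.
  return jnp.sum(P ** 4)

jtu.check_grads(eigh_only, (points,), 1)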