
Scaling behaviour of jax.test_util.check_grads

See original GitHub issue

It seems like check_grads has incorrect scaling behaviour for large input tensors.

In the code of check_jvp (which is called by check_grads), the following line essentially performs a weighted-sum reduction over args.size values:

v_out, t_out = f_jvp(args, tangent)

The tangent vector has a length that grows with the dimension of args. This would warrant scaling the tolerances by args.size, but no such scaling is performed. Instead, t_out is passed into

check_close(t_out, t_out_expected, atol=atol, rtol=rtol)

which will eventually call:

_assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)

Thus, the scaling happens by t_out.size, which is the height of the Jacobian, rather than by the width args.size, along which the reduction actually happened. Moreover, _assert_numpy_allclose performs an element-wise comparison with the given tolerances and then does

cond = reduced.all()

a reduction on a boolean array, so no tolerance scaling is needed in this function; if tolerance scaling is necessary at all, it should happen earlier, not here.
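
To make the wide-Jacobian case concrete, here is a small NumPy sketch (not the JAX code itself; the tolerance and finite-difference step are made up for illustration). The discrepancy between a finite-difference JVP of f(x) = sum(x) and the exact JVP sum(v) grows with the input size n, while the effective tolerance atol * t_out.size stays at the bare atol because t_out is a scalar; exact magnitudes will vary by platform.

# Illustrative sketch only, not the JAX implementation. f(x) = sum(x) has a
# scalar output, so t_out.size == 1 and the tolerance is never scaled, even
# though both JVP estimates reduce over args.size terms and their
# discrepancy grows with n.
import numpy as np

atol = 1e-6                 # illustrative tolerance, not JAX's default
eps = np.float32(1e-3)      # illustrative finite-difference step

rng = np.random.default_rng(0)
for n in (10, 10_000, 10_000_000):
    x = rng.standard_normal(n).astype(np.float32)
    v = rng.standard_normal(n).astype(np.float32)    # tangent vector

    fd_jvp = (np.sum(x + eps * v) - np.sum(x)) / eps  # numerical JVP estimate
    ad_jvp = np.sum(v)                                # exact JVP of sum(x)

    err = abs(float(fd_jvp) - float(ad_jvp))
    print(f"n={n:>10}  |t_out - t_out_expected| ~ {err:.1e}  "
          f"effective atol = atol * t_out.size = {atol:.0e}")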

This behaviour looks incorrect to me. Am I missing something?

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, May 19, 2020

I thought that custom_vjp decoration wouldn’t affect primal calculation when the forward pass is not doing anything except calling the base function.

That shouldn’t be too hard!

I’m going to try out the fix @nsavinov suggests and report back.

1 reaction
nsavinov commented, May 18, 2020

Both of your conclusions are true:

  1. For very wide Jacobians, the test is incorrectly strict and will fail for no real reason. This is the case of a large input tensor and a scalar output. This happened to me, which is why I started digging. 😃

  2. For very tall Jacobians, the test is incorrectly loose. This is the case of the artificial output broadcasting you mentioned (a toy sketch of this follows just below).
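
Here is a toy sketch of case 2 (not the actual JAX source; it only mimics the atol * a.size / rtol * b.size scaling quoted in the issue, with made-up tolerances): once the output is large, the scaled atol balloons, so a tangent that is off by a large absolute margin still passes.

# Toy sketch mimicking the scaling quoted above, not the real _assert_numpy_allclose.
# The effective element-wise check becomes
#     |a - b| <= atol * a.size + rtol * b.size * |b|
# so for a tall Jacobian (say, a scalar broadcast to n outputs) the threshold
# grows with n and a clearly wrong tangent still passes.
import numpy as np

def scaled_allclose(a, b, atol=1e-6, rtol=1e-6):     # illustrative tolerances
    return bool(np.all(np.abs(a - b) <= atol * a.size + rtol * b.size * np.abs(b)))

for n in (1_000, 1_000_000):
    t_out_expected = np.ones(n)          # stand-in for the numerical JVP
    t_out = t_out_expected + 0.5         # "AD" JVP that is off by 0.5 everywhere
    print(n, scaled_allclose(t_out, t_out_expected))
# Prints False for n=1_000 but True for n=1_000_000: the same absolute error
# is accepted once the output is large enough.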

If we want a minimal fix, we could swap where the scaling happens in check_jvp, like this:

check_close(t_out, t_out_expected, atol=atol * args.size, rtol=rtol * args.size)
_assert_numpy_allclose(a, b, atol=atol, rtol=rtol)

Some questions remain with this minimal fix though:

  1. Should something be done for check_vjp or is it already correct? I haven’t looked at that code.

  2. I am not sure whether scaling rtol (including in my fix above) is needed; it feels like scaling atol could be enough. This example shows a problem with atol but not with rtol (a runnable check follows after this list):

Assume a == b is "just within" atol, rtol:
    abs(a - b) = atol + rtol * b
Now take a1 = 100 * a and b1 = 100 * b:
    abs(a1 - b1) = 100 * (atol + rtol * b) = 100 * atol + rtol * b1 >> atol + rtol * b1
  3. Is linear scaling the right way to do it? If it is, maybe it is still worth a comment in the source code so that readers can follow the reasoning?

  4. Maybe it is worth clarifying for users what kind of functions check_grads expects, i.e. vector/scalar input, vector/scalar output?
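
As a quick numeric check of question 2 above (plain Python, illustrative tolerance values): if a and b sit just inside atol + rtol * |b|, then scaling both by 100 breaks the unscaled check, and scaling only atol is already enough to cover the gap.

# Numeric check of question 2, with illustrative tolerances.
atol, rtol = 1e-6, 1e-6

b = 1.0
a = b + 0.9 * (atol + rtol * abs(b))           # just inside the tolerance
a1, b1 = 100 * a, 100 * b                      # scale both values by 100

gap = abs(a1 - b1)                             # about 100 * 0.9 * (atol + rtol * |b|)
print(gap <= atol + rtol * abs(b1))            # False: the unscaled atol is too small
print(gap <= 100 * atol + rtol * abs(b1))      # True: scaling atol alone suffices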
