Scaling behaviour of jax.test_util.check_grads
It seems like check_grads has incorrect scaling behaviour for large input tensors.
In the code of check_jvp (which is called by check_grads), the following line essentially performs a weighted-sum reduction over args.size values:

v_out, t_out = f_jvp(args, tangent)

The vector tangent has a length that grows with the dimension of args, which would warrant scaling the tolerances by args.size. However, no such scaling is performed. Instead, t_out is passed into

check_close(t_out, t_out_expected, atol=atol, rtol=rtol)

which will eventually call:
which will eventually call:
_assert_numpy_allclose(a, b, atol=atol * a.size, rtol=rtol * b.size)
Thus, the scaling happens by t_out.size, which is the height of the Jacobian, instead of by args.size, the width along which the reduction actually happened. Moreover, _assert_numpy_allclose performs an element-wise comparison with the given tolerances and then does

cond = reduced.all()

which is a reduction over a boolean array, so no tolerance scaling is needed in this function at all. If tolerance scaling is necessary, it should happen earlier, not here.
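To illustrate why scaling tolerances inside an element-wise comparison is problematic, here is a self-contained NumPy sketch; the helper name and structure are mine, modeled on the description above, not JAX's actual code:

```python
import numpy as np

def assert_allclose_scaled(a, b, atol, rtol):
    # Tolerances are multiplied by the array size even though the
    # comparison is element-wise: no reduction over elements happens here.
    tol = atol * a.size + rtol * b.size * np.abs(b)
    reduced = np.abs(a - b) <= tol  # element-wise boolean array
    assert reduced.all()

a = np.zeros(10_000)
b = np.full(10_000, 0.5)  # every element is off by 0.5
# With atol=1e-4, scaling by size gives an effective atol of 1.0,
# so this grossly wrong comparison still passes, whereas an unscaled
# atol of 1e-4 would correctly reject it.
assert_allclose_scaled(a, b, atol=1e-4, rtol=0.0)
```

The scaling just uniformly loosens every per-element tolerance, which has nothing to do with how many values were summed over in the JVP.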
This behaviour looks incorrect to me. Am I missing something?
Issue Analytics
- Created: 3 years ago
- Comments: 7 (5 by maintainers)
That shouldn’t be too hard!
I’m going to try out the fix @nsavinov suggests and report back.
Both of your conclusions are true:
- For very wide Jacobians, the test is incorrectly strict and will fail for no reason. This is the case of a large input tensor and a scalar output. This happened to me, which is how I started digging. 😃
- For very tall Jacobians, the test is incorrectly loose. This is the case of the artificial output broadcasting that you mentioned.
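As a rough illustration of the wide-Jacobian case (plain NumPy, not JAX's actual JVP machinery): for f(x) = x.sum() the JVP along a tangent t is just t.sum(), a reduction over n values, and the float32 rounding error of that reduction grows with n while the output stays scalar:

```python
import numpy as np

rng = np.random.default_rng(0)
errors = {}
for n in (10, 100_000):
    t = rng.standard_normal(n).astype(np.float32)
    # Naive left-to-right float32 accumulation, as a stand-in for the
    # reduction a JVP of a scalar-valued function performs.
    acc = np.float32(0.0)
    for v in t:
        acc = np.float32(acc + v)
    exact = t.astype(np.float64).sum()
    errors[n] = abs(float(acc) - exact)
    print(n, errors[n])
```

Because the output is scalar here, scaling the tolerances by t_out.size == 1 gives no slack at all, even though the error grew with the input size n.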
If we want a minimal fix, we could swap where the scaling happens for check_jvp, like this:
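The code of the proposed fix did not survive in this copy of the issue; the following is only my reconstruction of the idea, with hypothetical names (tangent_size stands for the number of values reduced over, i.e. the width of the Jacobian):

```python
import numpy as np

def check_jvp_tolerances(t_out, t_out_expected, atol, rtol, tangent_size):
    # Hypothetical sketch: scale the tolerances once, by the length of
    # the tangent (the width of the Jacobian, over which the JVP sums),
    # and then compare element-wise WITHOUT any further size-based scaling.
    np.testing.assert_allclose(
        np.asarray(t_out), np.asarray(t_out_expected),
        atol=atol * tangent_size, rtol=rtol * tangent_size)

# Example: a difference of 5e-4 after reducing over 1000 inputs is
# accepted under a base atol of 1e-6 (effective atol 1e-3).
check_jvp_tolerances(np.array([1.0]), np.array([1.0005]),
                     atol=1e-6, rtol=0.0, tangent_size=1000)
```

The point is only where the scaling happens, not its exact form.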
Some questions remain with this minimal fix, though:
- Should something be done for check_vjp, or is it already correct? I haven’t looked at that code.
- I am not sure whether scaling rtol (including in my fix above) is needed. It feels like atol could be enough? This example shows a problem with atol but not with rtol.
- Is linear scaling the right way to do it? If it is, maybe it is still worth a comment in the source code so that readers can follow the reasoning.
- Maybe it is worth clarifying for users what kind of functions check_grads expects, i.e. vector/scalar input, vector/scalar output?
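For reference, a minimal usage sketch, assuming a current JAX install (check_grads compares autodiff derivatives against numerical differencing in both forward and reverse mode):

```python
import jax.numpy as jnp
from jax.test_util import check_grads

def f(x):
    # Vector input, scalar output: the "wide Jacobian" shape
    # discussed above.
    return jnp.sum(jnp.sin(x))

x = jnp.linspace(0.1, 1.0, 8)
# Check 1st- and 2nd-order derivatives; raises on mismatch.
check_grads(f, (x,), order=2)
```

Whether and how the tolerances should depend on x.size is exactly the question raised in this issue.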