Gradient of tf.pow
There are two issues that I’d like to point out with tf.pow’s gradient function:
First, there’s a bug with the gradient of tf.pow w.r.t. the exponent, whereby a non-positive base results in a NaN for the corresponding exponent’s gradient. This can be reproduced with:
```js
tf.version_core == "0.8.1";
tf.grad(x => tf.pow(tf.scalar(0), x))(tf.scalar(2)).dataSync();
tf.grad(x => tf.pow(tf.scalar(-2), x))(tf.scalar(2)).dataSync();
```
The backend pow is implemented as `pow(abs(a), b) * isEven(round(a))`, where `isEven` is an indicator function, `isEven: X → {1, -1}`. Despite this, the gradient function takes the log of the raw base, which may be negative, rather than the non-negative `abs(a)` that the backend actually exponentiates. One solution is to have an “unprotected” backend pow (which only receives non-negative bases) and move the abs and isEven components into the tf.pow op definition, where they would be cached for the backward pass. Although this takes care of the gradient when the base is negative, we still have an issue when the base is 0.
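To make the first point concrete, here is a rough sketch of what the exponent gradient could look like if it used the non-negative `abs(a)` that the backend actually exponentiates; the function name and structure are illustrative only, not the actual tfjs-core internals:

```js
const tf = require('@tensorflow/tfjs');

// Hypothetical sketch (not the real tfjs-core gradient code): compute
// d(a^b)/db as y * log(|a|), i.e. using the non-negative base that the
// backend actually raises to the power b, instead of log(a), which is
// NaN whenever a < 0.
function powGradWrtExponent(a, b, dy) {
  const y = tf.pow(a, b);                    // forward value: pow(abs(a), b) * isEven(round(a))
  return dy.mul(y.mul(tf.log(tf.abs(a))));   // y * log(|a|)
}

// With a = -2, b = 2 this gives 4 * ln(2) ≈ 2.77 instead of NaN.
powGradWrtExponent(tf.scalar(-2), tf.scalar(2), tf.scalar(1)).print();
```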
When the base is 0, the gradients of both the base and the exponent will be NaN:
```js
tf.version_core == "0.8.1";
tf.grad(x => tf.pow(x, tf.scalar(2)))(tf.scalar(0)).dataSync(); // grad of base
tf.grad(x => tf.pow(tf.scalar(0), x))(tf.scalar(2)).dataSync(); // grad of exponent
```
As just explained, when the base is 0, the gradient of the exponent will be NaN because we’re taking log(0), and the gradient of the base will be NaN because we’re dividing by 0. A simple solution would be to zero out the gradients of both the base and the exponent wherever the base is 0 (a sketch follows the list below).
This makes sense for the gradients (for a^x when a = 0) because:
- the derivative w.r.t. the exponent is a^x * log(a) = 0^x * log(0) = 0 * log(0), which tends to 0 (the 0^x factor dominates the log)
- the derivative w.r.t. the base is x * 0^(x-1) = 0
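Here is a sketch of the proposed zero-out, assuming a `tf.where`-based mask; the helper is hypothetical and only meant to illustrate the idea:

```js
const tf = require('@tensorflow/tfjs');

// Hypothetical helper: wherever the base is exactly 0, replace whatever
// came out of the usual gradient formula (possibly NaN) with 0.
function zeroOutWhereBaseIsZero(base, rawGrad) {
  const baseIsZero = tf.equal(base, tf.zerosLike(base));
  return tf.where(baseIsZero, tf.zerosLike(rawGrad), rawGrad);
}

// Example: the exponent gradient y * log(a) is NaN at a = 0 ...
const base = tf.tensor1d([0, 2]);
const exponent = tf.scalar(3);
const rawGrad = tf.pow(base, exponent).mul(tf.log(base));
// ... but after the zero-out this prints [0, 5.545] instead of [NaN, 5.545].
zeroOutWhereBaseIsZero(base, rawGrad).print();
```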
Fortunately, I believe that using the exponent’s gradient is a rare use case, because there’s rarely a path from a variable to an exponent (am I wrong to assume this?), so fixing the gradient of the exponent wouldn’t be such a high priority. However, fixing the gradient of the base should be done because it is commonly encountered, e.g. in regularization: if someone is using tf.pow (rather than tf.square) and a parameter reaches 0, a NaN will be produced. This will eventually propagate to the variables and mess everything up.
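For instance, a toy L2-style penalty written with tf.pow (a made-up example, not from the issue) shows how a zero-valued weight picks up a NaN gradient, which an optimizer would then write back into the variable:

```js
const tf = require('@tensorflow/tfjs');

// Toy illustration (hypothetical example): an L2-style penalty written
// with tf.pow instead of tf.square.
const l2Penalty = (weights) => tf.sum(tf.pow(weights, tf.scalar(2)));

const w = tf.tensor1d([0, 0.5, -1.2]);

// Mathematically the gradient is 2 * w = [0, 1, -2.4], but per the issue,
// on tfjs-core 0.8.1 the first entry comes out as NaN because the weight is 0.
// Once an optimizer applies this gradient, the NaN ends up in the variable.
tf.grad(l2Penalty)(w).print();
```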
Would y’all be open to these changes?
Top GitHub Comments
Opened a separate issue for the forward func #350.
This is fixed!