Getting nan for custom metric while training
I implemented Mean Average Precision (MAP@all) in TensorFlow like this:
def mean_avg_prec_tf(y_true, y_pred):
    dims = tf.shape(y_true)
    n = dims[0]
    k = dims[1]
    _, top_idx = tf.nn.top_k(y_pred, k)
    y_true = tf.to_float(y_true)
    top_idx = tf.to_float(top_idx)
    label_idx = tf.concat(1, [y_true, top_idx])
    label_idx = tf.reshape(label_idx, [n, 2, k])

    def avg_prec(label_idx):
        label = label_idx[0]
        idx = label_idx[1]
        idx = tf.to_int32(idx)
        ordered_pred = tf.gather(label, idx)
        prec = ordered_pred * tf.cumsum(ordered_pred)
        prec /= tf.to_float(tf.range(1, k + 1))
        prec = tf.reduce_sum(prec) / tf.reduce_sum(ordered_pred)
        return prec

    precs = tf.map_fn(avg_prec, label_idx)
    return tf.reduce_mean(precs)
This gives me nan on the training set during training, but the correct value on the validation set. Any idea how I can fix this?
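One plausible cause, assuming some training batches contain a row with no positive labels: the denominator tf.reduce_sum(ordered_pred) is then zero, the per-row average precision becomes 0/0, and a single nan row turns the whole batch mean into nan. Below is a minimal NumPy sketch of the same per-row computation. The name avg_prec_np is hypothetical and not part of the issue; the code only illustrates the arithmetic, not the TensorFlow graph itself.

```python
import numpy as np

def avg_prec_np(label, scores):
    """NumPy port of the per-row average precision from the question (hypothetical helper)."""
    label = np.asarray(label, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    k = label.shape[0]
    # Rank items by descending score, similar to tf.nn.top_k(y_pred, k).
    idx = np.argsort(-scores, kind="stable")
    ordered = label[idx]
    # Precision at each rank, kept only at ranks that hold a relevant item.
    prec = ordered * np.cumsum(ordered) / np.arange(1, k + 1, dtype=np.float32)
    # Unbuffered denominator, as in the question: zero when the row has no positives.
    with np.errstate(invalid="ignore", divide="ignore"):
        return prec.sum() / ordered.sum()

print(avg_prec_np([1, 0, 1], [0.9, 0.8, 0.7]))  # row with positives: finite value
print(avg_prec_np([0, 0, 0], [0.9, 0.8, 0.7]))  # row without positives: nan
```

If this is what happens, it would only take one such row in a training batch to produce nan for that batch.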
Issue Analytics
- State:
- Created 7 years ago
- Comments: 5 (2 by maintainers)
You can use K.epsilon() in place of 1e-12.
Added a buffer to the denominator.
This is working for now. Not sure if there is a clever way to solve this.
Closing the issue. Thanks for the help.
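The exact change is not shown in the thread, so the following is only a sketch of what a buffered denominator could look like, again as a NumPy port. The names avg_prec_safe and EPS are hypothetical. In the original TensorFlow function the corresponding line would presumably be something like tf.reduce_sum(prec) / (tf.reduce_sum(ordered_pred) + K.epsilon()), which is an assumption and not taken from the issue.

```python
import numpy as np

EPS = 1e-7  # assumed buffer size; Keras' K.epsilon() defaults to 1e-7

def avg_prec_safe(label, scores, eps=EPS):
    """Per-row average precision with a small buffer added to the denominator."""
    label = np.asarray(label, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    k = label.shape[0]
    # Rank items by descending score.
    idx = np.argsort(-scores, kind="stable")
    ordered = label[idx]
    # Precision at each rank, kept only at ranks that hold a relevant item.
    prec = ordered * np.cumsum(ordered) / np.arange(1, k + 1, dtype=np.float32)
    # Buffered denominator: 0 / (0 + eps) evaluates to 0 instead of nan.
    return prec.sum() / (ordered.sum() + eps)

print(avg_prec_safe([0, 0, 0], [0.9, 0.8, 0.7]))  # row without positives: 0.0
```

One consequence of this choice is that rows without positive labels count as 0 in the mean, which lowers the reported score. An alternative would be to exclude such rows before averaging; newer TensorFlow releases also provide tf.math.divide_no_nan, which returns 0 wherever the denominator is 0.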