plot_confusion_matrix example breaks down if not all classes are present in the test data
Description
The plot_confusion_matrix example breaks down without any warning or error if the test data does not contain all labels. This can easily happen with imbalanced datasets, or with real datasets that have many classes.
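A minimal sketch of the likely cause, using made-up toy labels rather than the iris split from this report: when `labels` is not passed, `confusion_matrix` builds the matrix only from the classes that occur in `y_true` or `y_pred`, so the result can be smaller than the list of class names used for the tick labels.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Three known classes, as in the iris data set.
class_names = ["setosa", "versicolor", "virginica"]

# Toy labels (illustrative only): class 0 never appears in either array.
y_true = np.array([1, 1, 2, 2, 2])
y_pred = np.array([1, 2, 2, 2, 1])

# Without `labels`, only the classes seen in the data get a row and column.
cm_inferred = confusion_matrix(y_true, y_pred)

# With `labels`, the matrix always has one row and column per known class.
cm_full = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))

print(cm_inferred.shape)
print(cm_full.shape)
```

In the first case the matrix has two rows while three class names are passed to the plotting function, so the tick labels no longer line up with the cells.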
Steps/Code to Reproduce
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names
# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=45, test_size=0.05)
# Run classifier (linear SVM with the default C)
classifier = svm.SVC(kernel='linear')
y_pred = classifier.fit(X_train, y_train).predict(X_test)
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')
plt.show()
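One way to turn the silent mismatch into an explicit error is a small guard that runs before plotting. The helper below is hypothetical (the name `check_matrix_matches_classes` is not part of scikit-learn or of the original example); it only checks that the matrix is square with one row per class name.

```python
import numpy as np


def check_matrix_matches_classes(cm, classes):
    """Return cm as an array, or raise if its shape does not match classes.

    Hypothetical helper for illustration; not part of scikit-learn.
    """
    cm = np.asarray(cm)
    n_classes = len(classes)
    if cm.shape != (n_classes, n_classes):
        raise ValueError(
            "Confusion matrix has shape %r but %d class names were given; "
            "some classes are probably missing from the test data."
            % (cm.shape, n_classes)
        )
    return cm
```

Such a check could be called at the top of plot_confusion_matrix. Passing `labels=np.arange(len(class_names))` to `confusion_matrix` avoids the mismatch in the first place, though with `normalize=True` a class that has no true samples then yields a row of NaN values, because its row sum is zero.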
Expected Results
Actual Results
Versions
Windows-7-6.1.7601-SP1
Python 3.6.5 |Anaconda, Inc.| (default, Mar 29 2018, 13:32:41) [MSC v.1900 64 bit (AMD64)]
NumPy 1.14.3
SciPy 1.1.0
Scikit-Learn 0.19.1
Issue Analytics
- State:
- Created 5 years ago
- Reactions:2
- Comments:16 (13 by maintainers)
Top GitHub Comments
@sanu11 @MLopez-Ibanez has something in his fork and there’s a PR from @trungpham10, feel free to take it if they don’t mind or if there’s no reply after some time.
That’s ok with me.
On Mon, 17 Dec 2018, 09:24 Adrin Jalali <notifications@github.com wrote: