
plot_confusion_matrix example breaks down if not all classes are present in the test data


Description

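The example at https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py breaks silently, with no warning or error, if the test data does not contain all labels. This happens easily with imbalanced datasets, or with real datasets that have many classes.

When it is called without labels, confusion_matrix builds its rows and columns only from the labels that occur in y_test or y_pred. If a class occurs in neither, the matrix is smaller than class_names, but plot_confusion_matrix still draws one tick label per class name, so the tick labels can end up naming the wrong rows and columns.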
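A minimal illustration of this behaviour with toy labels (the arrays are invented for the demonstration, not taken from the iris split): class 1 occurs in neither array, so it drops out of the default matrix and only reappears when the full label set is passed explicitly.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels (invented for illustration): class 1 occurs in neither array.
y_true = np.array([0, 0, 2])
y_pred = np.array([0, 2, 2])

# Default: rows/columns come only from the observed labels, i.e. [0, 2].
cm_default = confusion_matrix(y_true, y_pred)

# Passing the full label set keeps one row and one column per class.
cm_full = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

print(cm_default.shape)  # (2, 2)
print(cm_full.shape)     # (3, 3)
```

With three class names and the 2x2 default matrix, the second tick label would be placed on the column that actually holds class 2.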

Steps/Code to Reproduce

import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=45, test_size=0.05)

# Run a linear SVM classifier with the default regularization (C=1.0)
classifier = svm.SVC(kernel='linear')
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        # A class with no true samples has a zero row sum and yields NaN here.
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()


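# Compute confusion matrix. Without labels=..., confusion_matrix only creates
# rows/columns for classes that occur in y_test or y_pred.
cnf_matrix = confusion_matrix(y_test, y_pred)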
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                      title='Confusion matrix, without normalization')
plt.show()
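One way to make the example robust, sketched under the assumption that the class labels are the integers 0..len(class_names)-1 in the same order as class_names (true for iris): pass the full label set to confusion_matrix, and guard the normalization against classes with no true samples in the test split, which would otherwise divide by zero and produce NaN rows. The helper normalize_rows is hypothetical, not part of scikit-learn.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X, y = iris.data, iris.target
class_names = iris.target_names

# Same tiny test split as the report, so some classes may be missing from it.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=45, test_size=0.05)
y_pred = svm.SVC(kernel='linear').fit(X_train, y_train).predict(X_test)

# Assumption: labels are 0..n_classes-1, matching the order of class_names.
labels = np.arange(len(class_names))
cnf_matrix = confusion_matrix(y_test, y_pred, labels=labels)


def normalize_rows(cm):
    """Hypothetical helper: row-normalize cm, leaving all-zero rows as 0.0 instead of NaN."""
    cm = cm.astype(float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Divide only where a row has at least one sample; elsewhere keep the 0.0 from out.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)


cnf_normalized = normalize_rows(cnf_matrix)
print(cnf_matrix.shape)  # (3, 3), whatever classes the split happens to contain
```

Because the matrix now always has one row and one column per entry of class_names, the tick labels drawn by plot_confusion_matrix line up again. Inside that function, the normalization line could call normalize_rows(cm) instead of dividing directly.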

Expected Results

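A plot with one row and one column per class in class_names, each labeled with the correct class, or at least a warning or error when the labels cannot be matched.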

Actual Results

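No warning or error is raised. When a class occurs in neither y_test nor y_pred, confusion_matrix returns a matrix with fewer rows and columns than there are entries in class_names, so the class names on the axes can end up attached to the wrong rows and columns.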
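If the confusion_matrix call itself cannot be changed, a cheap shape check before plotting at least turns the silent mislabeling into an error. A minimal sketch; check_matrix_matches_classes is a hypothetical helper, not part of scikit-learn or the original example.

```python
import numpy as np


def check_matrix_matches_classes(cm, classes):
    """Hypothetical guard: raise if cm cannot be labeled one-to-one by classes."""
    cm = np.asarray(cm)
    n = len(classes)
    if cm.shape != (n, n):
        raise ValueError(
            f"confusion matrix has shape {cm.shape} but {n} class names were given; "
            "pass labels= to confusion_matrix so every class gets a row and a column")
    return cm


# A 2x2 matrix with three class names is rejected instead of being mislabeled.
try:
    check_matrix_matches_classes([[1, 1], [0, 1]], ["setosa", "versicolor", "virginica"])
except ValueError as exc:
    print(exc)
```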

Versions

Windows-7-6.1.7601-SP1
Python 3.6.5 |Anaconda, Inc.| (default, Mar 29 2018, 13:32:41) [MSC v.1900 64 bit (AMD64)]
NumPy 1.14.3
SciPy 1.1.0
Scikit-Learn 0.19.1

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Reactions: 2
  • Comments: 16 (13 by maintainers)

Top GitHub Comments

1 reaction
qinhanmin2014 commented, Feb 9, 2019

@sanu11 @MLopez-Ibanez has something in his fork, and there’s a PR from @trungpham10. Feel free to take it over if they don’t mind, or if there’s no reply after some time.

1 reaction
MLopez-Ibanez commented, Dec 17, 2018

That’s ok with me.

On Mon, 17 Dec 2018, 09:24 Adrin Jalali <notifications@github.com> wrote:

Sure, go ahead @trungpham10, unless @MLopez-Ibanez is working on a PR.

View it on GitHub: https://github.com/scikit-learn/scikit-learn/issues/12700#issuecomment-447775770
