Unavoidable "y_true and y_pred contain different number of classes" error inside a CV loop

Description

During cross-validation on a multi-class problem, it’s technically possible to have classes present in the test data that don’t appear in the training data.

Steps/Code to Reproduce

import numpy as np
from sklearn.metrics import make_scorer, log_loss
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.naive_bayes import BernoulliNB

rs = np.random.RandomState(1389057)

y = [
    'cow',
    'hedgehog',
    'fox',
    'fox',
    'hedgehog',
    'fox',
    'hedgehog',
    'cow',
    'cow',
    'fox'
]

x = rs.normal([0, 0], [1, 1], size=(len(y), 2))

model = BernoulliNB()

cv = StratifiedKFold(4, shuffle=True, random_state=rs)

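param_dist = {
    # np.logspace takes base-10 exponents, so use log10 to span 0.1 to 1
    'alpha': np.logspace(np.log10(0.1), np.log10(1), 20)
}

# log_loss is a loss (lower is better), hence greater_is_better=False
search = RandomizedSearchCV(model, param_dist, n_iter=5,
                            scoring=make_scorer(log_loss, greater_is_better=False,
                                                needs_proba=True),
                            cv=cv)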

search.fit(x, y)

Expected Results

Either:

  1. Predicted classes from predict_proba are aligned with classes in the full training data, not just the in-fold subset.
  2. Classes not in the training data are ignored in the test data.

Actual Results

Predicted classes from predict_proba are aligned with classes in the in-fold subset only, but classes not in the training data are still used in the test data, causing the error.
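
The mismatch can be reproduced without any cross-validation machinery. The following is a minimal sketch with made-up probabilities (the numbers and the two-sample fold are illustrative assumptions, not output from the report above): log_loss sees three probability columns but only two distinct labels in y_true, and succeeds once the full label set is supplied through its labels argument.

```python
import numpy as np
from sklearn.metrics import log_loss

# Illustrative probabilities for three classes, with columns ordered
# as ['cow', 'fox', 'hedgehog'] (the sorted class names).
y_prob = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
])

# This hypothetical test fold happens to contain only two of the three classes.
y_true = ['cow', 'fox']

try:
    log_loss(y_true, y_prob)
    mismatch_raised = False
except ValueError:
    # Two classes inferred from y_true versus three probability columns.
    mismatch_raised = True

# Supplying every label tells log_loss how to line up the columns.
loss = log_loss(y_true, y_prob, labels=['cow', 'fox', 'hedgehog'])
```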

I understand that this is normatively “correct” behavior, but it makes it hard/impossible to use in cross-validation with the existing APIs.

From my perspective, the best solution would be to have RandomizedSearchCV pass a labels=self.classes_ argument to its scorer. I’m not sure how well that generalizes.
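
Until something like that exists, one possible workaround is to pass a plain callable as the scorer, since scikit-learn accepts any function with the signature (estimator, X, y). The sketch below is an assumption about how this could be done, not an API that RandomizedSearchCV provides: neg_log_loss_with_classes is a hypothetical helper name, and the integer seeds are arbitrary. It only helps when every class in a test fold was also seen during training.

```python
import numpy as np
from sklearn.metrics import log_loss
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.naive_bayes import BernoulliNB


def neg_log_loss_with_classes(estimator, X, y):
    """Hypothetical scorer: align log_loss with the fitted estimator's classes."""
    y_prob = estimator.predict_proba(X)
    # Negated so that a larger score is better, as the search expects.
    return -log_loss(y, y_prob, labels=estimator.classes_)


y = ['cow', 'hedgehog', 'fox', 'fox', 'hedgehog',
     'fox', 'hedgehog', 'cow', 'cow', 'fox']
x = np.random.RandomState(1389057).normal([0, 0], [1, 1], size=(len(y), 2))

search = RandomizedSearchCV(
    BernoulliNB(),
    {'alpha': np.logspace(-1, 0, 20)},
    n_iter=5,
    scoring=neg_log_loss_with_classes,
    cv=StratifiedKFold(4, shuffle=True, random_state=0),
    random_state=0,
    error_score='raise',  # surface scoring failures instead of recording NaN
)
search.fit(x, y)
```

Because the labels come from the fitted estimator rather than from the test fold, the probability columns and the label set always have the same length.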

Versions

Linux-3.10.0-514.26.2.el7.x86_64-x86_64-with-redhat-7.3-Maipo
Python 3.6.6 |Anaconda, Inc.| (default, Jun 28 2018, 17:14:51) [GCC 7.2.0]
NumPy 1.15.0
SciPy 1.1.0
Scikit-Learn 0.19.1

Issue Analytics

  • State: open
  • Created 5 years ago
  • Reactions: 1
  • Comments: 15 (6 by maintainers)

Top GitHub Comments

3 reactions
NicolasHug commented, Aug 7, 2018

What about scoring=make_scorer(log_loss, needs_proba=True, labels=y)?
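
A rough sketch of this suggestion follows, with a few assumptions of my own layered on top: np.unique(y) is used instead of the raw y so that each label appears once, greater_is_better=False is added because log_loss is a loss, and the signature check is there because newer scikit-learn releases select predict_proba through response_method rather than needs_proba. The seed for the folds is arbitrary.

```python
import inspect

import numpy as np
from sklearn.metrics import log_loss, make_scorer
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.naive_bayes import BernoulliNB

y = ['cow', 'hedgehog', 'fox', 'fox', 'hedgehog',
     'fox', 'hedgehog', 'cow', 'cow', 'fox']
x = np.random.RandomState(1389057).normal([0, 0], [1, 1], size=(len(y), 2))

# Every class in the full data set, each listed once.
all_labels = np.unique(y)

# Pick whichever keyword this scikit-learn version uses to request
# predict_proba output from the estimator.
if 'response_method' in inspect.signature(make_scorer).parameters:
    proba_kwargs = {'response_method': 'predict_proba'}
else:
    proba_kwargs = {'needs_proba': True}

# Extra keyword arguments (here: labels) are forwarded to log_loss.
scorer = make_scorer(log_loss, greater_is_better=False,
                     labels=all_labels, **proba_kwargs)

scores = cross_val_score(
    BernoulliNB(), x, y, scoring=scorer,
    cv=StratifiedKFold(4, shuffle=True, random_state=0),
)
```

This covers test folds that are missing a class. It would not cover a training fold that is missing one, because predict_proba would then return fewer columns than there are labels.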

2 reactions
qinj commented, Nov 8, 2019

Use the model.classes_ attribute:

y_probs = model.predict_proba(X_test)
log_loss(y_test, y_probs, labels=model.classes_)
