
CategoricalNB bug with categories present in test but absent in train


Description

Calling predict(), predict_proba() or predict_log_proba() on a fitted CategoricalNB model raises an IndexError when the test data contains category values that never appeared in the training data.

Steps/Code to Reproduce

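import numpy as np
from sklearn.datasets import make_classification
from sklearn.naive_bayes import CategoricalNB
from sklearn.model_selection import train_test_split

# Three-class problem with 10 continuous features
X, y = make_classification(n_features=10, n_classes=3, n_samples=1000, random_state=42,
                           n_redundant=0, n_informative=6)
# Truncate to non-negative integers so every feature looks categorical
# (builtin int instead of np.int, which was removed in NumPy 1.24)
X = np.abs(X.astype(int))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
model = CategoricalNB().fit(X_train, y_train)

# Raises IndexError on scikit-learn 0.22.1: some values in X_test exceed
# the largest category seen for that feature in X_train
model.predict(X_test)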

Expected Results

Predictions for X_test (integer class labels).

Actual Results

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-3-551bb8080923> in <module>
     10 model = CategoricalNB().fit(X_train, y_train)
     11 
---> 12 model.predict(X_test)

~/Documents/MachineLearning/onnx_projects/skl_env/lib/python3.6/site-packages/sklearn/naive_bayes.py in predict(self, X)
     75         check_is_fitted(self)
     76         X = self._check_X(X)
---> 77         jll = self._joint_log_likelihood(X)
     78         return self.classes_[np.argmax(jll, axis=1)]
     79 

~/Documents/MachineLearning/onnx_projects/skl_env/lib/python3.6/site-packages/sklearn/naive_bayes.py in _joint_log_likelihood(self, X)
   1217         for i in range(self.n_features_):
   1218             indices = X[:, i]
-> 1219             jll += self.feature_log_prob_[i][:, indices].T
   1220         total_ll = jll + self.class_log_prior_
   1221         return total_ll

IndexError: index 5 is out of bounds for axis 1 with size 5
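The flagged line indexes self.feature_log_prob_[i] (one array of shape (n_classes, n_categories) per feature, judging by the traceback) with the raw category values of feature i. A minimal NumPy sketch, using made-up shapes (3 classes, categories 0 to 4 seen in training), reproduces the same error outside scikit-learn:

```python
import numpy as np

# Made-up stand-in for one entry of feature_log_prob_: 3 classes,
# and only categories 0..4 of this feature occurred during training.
feature_log_prob_i = np.full((3, 5), np.log(0.2))

# Category values of this feature in some test rows; 5 never occurred in training.
indices = np.array([0, 3, 5])

message = None
try:
    # Same integer-array indexing pattern as the failing line in the traceback.
    contrib = feature_log_prob_i[:, indices].T
except IndexError as exc:
    message = str(exc)
    print(message)  # index 5 is out of bounds for axis 1 with size 5
```

In other words, the number of columns per feature is fixed at fit time, so any larger category value at predict time has nowhere to be looked up.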

Versions

System:
    python: 3.6.8 (v3.6.8:3c6b436a57, Dec 24 2018, 02:04:31) [GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)]
    executable: /Users/prroy/Documents/MachineLearning/onnx_projects/skl_env/bin/python3
    machine: Darwin-19.2.0-x86_64-i386-64bit

Python dependencies:
    pip: 18.1
    setuptools: 40.6.2
    sklearn: 0.22.1
    numpy: 1.18.0
    scipy: 1.4.1
    Cython: 0.29.14
    pandas: 0.25.3
    matplotlib: None
    joblib: 0.14.1

Built with OpenMP: True
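scikit-learn 0.22.1, listed above, has no way to declare categories that are absent from the training set. Newer releases (0.24 and later) add a min_categories parameter to CategoricalNB. The sketch below applies it to the same repro, assuming the full range of category values per feature is known in advance; here that range is taken from the whole dataset purely for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import CategoricalNB

X, y = make_classification(n_features=10, n_classes=3, n_samples=1000,
                           random_state=42, n_redundant=0, n_informative=6)
X = np.abs(X.astype(int))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)

# Assumption: the number of categories per feature is known up front.
# Computing it on all of X is only for this illustration.
n_categories = X.max(axis=0) + 1

# min_categories (scikit-learn >= 0.24) makes feature_log_prob_[i] at least
# n_categories[i] columns wide, even for categories absent from X_train.
model = CategoricalNB(min_categories=n_categories).fit(X_train, y_train)
predictions = model.predict(X_test)
```

Categories that never occur in X_train still get a column, but their probabilities come only from the alpha smoothing term, so they carry little information.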

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 5
  • Comments: 12 (7 by maintainers)

Top GitHub Comments

5 reactions
glemaitre commented, Jan 6, 2020

Some categories are present during testing but never seen during training. We should probably have a strategy to handle unknown categories or at least raise a proper error message.
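The second option in this comment, raising a proper error message, could look roughly like the hypothetical helper below. It is not part of scikit-learn: the name check_known_categories is made up, and it only assumes that feature_log_prob_ holds one (n_classes, n_categories) array per feature, as the traceback suggests. It does not check for negative values.

```python
import numpy as np


def check_known_categories(feature_log_prob, X):
    """Hypothetical helper: raise a readable ValueError if any column of X
    holds a category index that has no column in feature_log_prob."""
    X = np.asarray(X)
    for i, log_prob in enumerate(feature_log_prob):
        n_seen = log_prob.shape[1]  # categories 0..n_seen-1 are known for feature i
        column = X[:, i]
        unseen = np.unique(column[column >= n_seen])
        if unseen.size:
            raise ValueError(
                f"Feature {i} contains categories {unseen.tolist()} not seen "
                f"during fit; only categories 0..{n_seen - 1} are known."
            )


# Usage with made-up arrays shaped like CategoricalNB.feature_log_prob_.
fake_log_prob = [np.zeros((3, 5)), np.zeros((3, 2))]
check_known_categories(fake_log_prob, [[4, 1], [0, 0]])  # passes silently

error_message = None
try:
    check_known_categories(fake_log_prob, [[5, 1], [0, 2]])
except ValueError as exc:
    error_message = str(exc)
    print(error_message)
```

With a fitted model, the call would be check_known_categories(model.feature_log_prob_, X_test) just before model.predict(X_test).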

1 reaction
jnothman commented, Jan 28, 2020

I would be happy with that solution.


