SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss)
Description
SGDClassifier cannot be used with MultiOutputClassifier's predict_proba(), even when it is configured with a loss that supports probability estimates (log or modified_huber).
The incompatibility occurs because SGDClassifier instances do not have the attribute "predict_proba", so when one is wrapped by MultiOutputClassifier, predict_proba() raises an error. The check that raises it is in this file: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/multioutput.py
At this condition:
if not hasattr(self.estimator, "predict_proba"):
    raise ValueError("The base estimator should implement"
                     "predict_proba method")
For comparison, in the simplified example below, LogisticRegression classifiers do have the attribute and work correctly.
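One way a hasattr() check like the one above can fail even though the method is defined on the class is when the attribute is a property that raises AttributeError under some conditions. The sketch below shows that mechanism with a toy class. ToyEstimator, its coef_ attribute, and its error messages are hypothetical, and whether scikit-learn 0.19.1 fails for this reason is an assumption, not something stated in this issue.

```python
class ToyEstimator:
    """Hypothetical stand-in for illustration; not scikit-learn code."""

    def __init__(self, loss="hinge"):
        self.loss = loss

    def fit(self):
        # Pretend to learn something so the estimator counts as fitted.
        self.coef_ = [0.0]
        return self

    @property
    def predict_proba(self):
        # hasattr() returns False whenever this property raises AttributeError.
        if self.loss not in ("log", "modified_huber"):
            raise AttributeError("no probability estimates for loss=%r" % self.loss)
        # Assumption: the toy also hides the method until fit() has been called.
        if not hasattr(self, "coef_"):
            raise AttributeError("this estimator is not fitted yet")
        return self._predict_proba

    def _predict_proba(self, X):
        # Dummy uniform probabilities, one row per sample.
        return [[0.5, 0.5] for _ in X]


unfitted = ToyEstimator(loss="log")
fitted = ToyEstimator(loss="log").fit()
wrong_loss = ToyEstimator(loss="hinge").fit()

print(hasattr(unfitted, "predict_proba"))    # False: right loss, but not fitted
print(hasattr(fitted, "predict_proba"))      # True
print(hasattr(wrong_loss, "predict_proba"))  # False: loss has no probabilities
```

If the real estimator behaves like this toy, the result of hasattr() depends on when and on which object it is called, not only on the chosen loss.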
Steps/Code to Reproduce
from sklearn.linear_model import SGDClassifier as online
from sklearn.linear_model import LogisticRegression as log

# use either one because they allow predict_proba() with SGDClassifier alone:
clf_test = online(loss="log", penalty="l2")
# clf_test = online(loss="modified_huber", penalty="l2")

# The problematic condition in MultiOutputClassifier's predict_proba():
if not hasattr(clf_test, "predict_proba"):
    print("Don't allow predict_proba() when wrapped by MultiOutputClassifier.")
else:
    print("Allow predict_proba() when wrapped by MultiOutputClassifier.")

# By contrast, the logistic regression classifier would work.
clf_test = log()
if not hasattr(clf_test, "predict_proba"):
    print("Don't allow predict_proba() when wrapped by MultiOutputClassifier.")
else:
    print("Allow predict_proba() when wrapped by MultiOutputClassifier.")
Expected Results
Allow predict_proba() when wrapped by MultiOutputClassifier.
Allow predict_proba() when wrapped by MultiOutputClassifier.
Actual Results
Don't allow predict_proba() when wrapped by MultiOutputClassifier.
Allow predict_proba() when wrapped by MultiOutputClassifier.
Versions
Windows-10-10.0.15063
('Python', '2.7.11 |Anaconda custom (32-bit)| (default, Mar 4 2016, 15:18:41) [MSC v.1500 32 bit (Intel)]')
('NumPy', '1.10.4')
('SciPy', '0.17.0')
('Scikit-Learn', '0.19.1')
Issue Analytics
- Created 6 years ago
- Comments:14 (9 by maintainers)
Top GitHub Comments
no, it’s checking at the right time, but checking the wrong thing.
@jnothman, I think @TomDLT also suggested the same thing but as @amueller pointed out, it would still fail for certain (obscure) cases.
Here’s an edge case I came up with in #12222
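As an illustration of the "checking the wrong thing" remark (and not a reproduction of the edge case from #12222), the sketch below shows a toy wrapper that runs the hasattr() check on its fitted copies instead of on the unfitted template. ToyBase, ToyMultiOutput, and the estimators_ list are hypothetical names chosen for this example. Reading the remark this way is an assumption, and this is not presented as the fix scikit-learn adopted.

```python
import copy


class ToyBase:
    """Hypothetical base estimator: predict_proba exists only after fit()."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        return self

    @property
    def predict_proba(self):
        if not hasattr(self, "classes_"):
            raise AttributeError("not fitted")
        n = len(self.classes_)
        # Dummy uniform probabilities over the classes seen during fit().
        return lambda X: [[1.0 / n] * n for _ in X]


class ToyMultiOutput:
    """Hypothetical wrapper that fits one copy of the template per output column."""

    def __init__(self, estimator):
        self.estimator = estimator  # template, never fitted itself

    def fit(self, X, Y):
        n_outputs = len(Y[0])
        self.estimators_ = [
            copy.deepcopy(self.estimator).fit(X, [row[j] for row in Y])
            for j in range(n_outputs)
        ]
        return self

    def predict_proba(self, X):
        # Check the fitted copies rather than the unfitted template.
        if not all(hasattr(est, "predict_proba") for est in self.estimators_):
            raise ValueError("The base estimator should implement "
                             "predict_proba method")
        return [est.predict_proba(X) for est in self.estimators_]


X = [[0], [1], [2]]
Y = [[0, 1], [1, 0], [0, 1]]
model = ToyMultiOutput(ToyBase()).fit(X, Y)

print(hasattr(model.estimator, "predict_proba"))  # False: template is unfitted
probs = model.predict_proba(X)
print(len(probs))  # 2: one probability array per output column
```

In this toy, a check against model.estimator would reject a base estimator that works once fitted, while the check against the fitted copies passes.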