API Inconsistency of predict and predict_proba in SVC
When using SVC(probability=True) or NuSVC(probability=True), the output of predict_proba will not necessarily be consistent with predict, in the sense that
np.argmax(self.predict_proba(X), axis=1) != self.predict(X)
can hold for some samples.
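The expression above compares column indices with class labels, so as written it only works when classes_ is 0 to n_classes - 1. A minimal sketch of a label-aware check, assuming a fitted scikit-learn classifier exposing predict_proba and classes_ (the helper name inconsistent_mask is hypothetical, and iris is just a convenient dataset):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.svm import SVC


def inconsistent_mask(clf, X):
    """Hypothetical helper: True where predict disagrees with argmax(predict_proba).

    Argmax column indices are mapped back to labels through clf.classes_,
    so this also works for non-integer or non-contiguous labels.
    """
    labels_from_proba = clf.classes_[np.argmax(clf.predict_proba(X), axis=1)]
    return labels_from_proba != clf.predict(X)


X, y = load_iris(return_X_y=True)
clf = SVC(probability=True, random_state=0).fit(X, y)
mask = inconsistent_mask(clf, X)
print(f"{mask.sum()} of {len(X)} samples disagree")
```

The number of disagreements depends on the data and on the internal cross-validation used to fit the probability model, so it can be zero on one dataset and non-zero on another.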
This is documented in the user guide:
In addition, the probability estimates may be inconsistent with the scores, in the sense that the “argmax” of the scores may not be the argmax of the probabilities. (E.g., in binary classification, a sample may be labeled by predict as belonging to a class that has probability <½ according to predict_proba.) Platt’s method is also known to have theoretical issues.
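For the binary case, the quoted behaviour can be illustrated with Platt's sigmoid as used by libsvm, P(y=1 | f) = 1 / (1 + exp(A*f + B)), where f is the decision value. predict thresholds f at 0, whereas the probability crosses 0.5 at f = -B/A, so whenever B is not 0 the two cut-offs differ. A minimal sketch; the values of A and B below are made up for illustration, not fitted ones:

```python
import math


def platt_probability(f, A, B):
    """Platt's sigmoid P(y=1 | f) = 1 / (1 + exp(A*f + B)) for a decision value f."""
    return 1.0 / (1.0 + math.exp(A * f + B))


# Hypothetical sigmoid parameters (A < 0, so a larger f gives a higher P(y=1)).
A, B = -2.0, 0.5

f = 0.1  # positive decision value: a sign-based predict assigns the positive class
p = platt_probability(f, A, B)
print(f"decision value {f}: predicted positive = {f > 0}, P(y=1) = {p:.3f}")
print(f"P(y=1) crosses 0.5 at f = {-B / A}")
```

Any sample with a decision value between 0 and -B/A (here between 0 and 0.25) is labeled positive by predict while getting a probability below one half.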
IMO this is a violation of the API contract and should be fixed.
This is repeatedly reported as a bug, e.g. in https://github.com/scikit-learn/scikit-learn/issues/4800, https://github.com/scikit-learn/scikit-learn/issues/12408 and https://github.com/scikit-learn/scikit-learn/issues/12982, as well as in a few Stack Overflow posts, e.g. https://stackoverflow.com/a/17019830
I encountered this in a project where detecting the discrepancy, evaluating the difference, and deciding whether predict or argmax(predict_proba) should be used in the end took some effort.
One pitfall, for instance, is to use predict to compute the accuracy and then predict_proba for ROC AUC, which can lead to misleading results when the predictions of these two methods are not consistent.
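A minimal sketch of this pitfall on a binary problem (the dataset, split and hyperparameters are arbitrary choices for illustration): accuracy from predict, accuracy from argmax(predict_proba), and ROC AUC from predict_proba. If the two accuracies differ, the accuracy and the ROC AUC are not describing the same decision rule.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = SVC(probability=True, random_state=0).fit(X_train, y_train)

proba = clf.predict_proba(X_test)
y_pred = clf.predict(X_test)
# Labels implied by the probabilities, mapped through classes_.
y_pred_from_proba = clf.classes_[np.argmax(proba, axis=1)]

acc_predict = accuracy_score(y_test, y_pred)
acc_proba = accuracy_score(y_test, y_pred_from_proba)
# ROC AUC uses the probability of the positive class, classes_[1].
auc = roc_auc_score(y_test, proba[:, 1])

print(f"accuracy (predict):              {acc_predict:.3f}")
print(f"accuracy (argmax predict_proba): {acc_proba:.3f}")
print(f"ROC AUC (predict_proba):         {auc:.3f}")
print(f"disagreements: {(y_pred != y_pred_from_proba).sum()}")
```

Deriving both the hard labels and the ranking metric from the same scores avoids mixing two decision rules in one report.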
Several approaches could be used to fix it:
- Deprecate the probability=True parameter in the SVC and NuSVC estimators and suggest using CalibratedClassifierCV(SVC(), cv=5) instead. In my quick tests (on sparse data), the latter was actually faster and should yield comparable results that are also consistent between predict and predict_proba, though more benchmarks may be needed. There may also be some variation in the results, as libsvm uses a generalization of Platt scaling to the multiclass case by Wu et al. (2004) (cf. the docs), which, as far as I understand, is not used in CalibratedClassifierCV. One possibility could be to deprecate the parameter but keep it, to allow access to that functionality in libsvm.
- Compute predict as argmax(predict_proba) when probability=True. This has the disadvantage of making the results of predict depend on this input parameter.
- Dig into libsvm to understand how it could be fixed there.
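A minimal sketch comparing the first option with the current behaviour, assuming a recent scikit-learn (the base estimator is passed to CalibratedClassifierCV positionally because its keyword name changed from base_estimator to estimator across versions). CalibratedClassifierCV computes predict as the argmax of its own predict_proba, so its mismatch count is zero by construction; no timing or accuracy comparison is implied here.

```python
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)


def n_mismatches(clf, X):
    """Count samples where predict differs from the most probable class label."""
    from_proba = clf.classes_[np.argmax(clf.predict_proba(X), axis=1)]
    return int((clf.predict(X) != from_proba).sum())


# Current behaviour: libsvm's internal Platt scaling (pairwise coupling for multiclass).
svc_proba = SVC(probability=True, random_state=0).fit(X, y)

# Option 1: calibrate a plain SVC with 5-fold cross-validation (sigmoid, i.e. Platt scaling).
calibrated = CalibratedClassifierCV(SVC(), cv=5).fit(X, y)

print("SVC(probability=True):", n_mismatches(svc_proba, X))
print("CalibratedClassifierCV(SVC(), cv=5):", n_mismatches(calibrated, X))
```

Note that in the multiclass case CalibratedClassifierCV calibrates each column of decision_function separately (one-vs-rest) rather than using Wu et al.'s pairwise coupling, which is one source of the variation mentioned above.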
Top GitHub Comments
Here is a minimal example,
which returns,
Not that using SVC in this use case is a good idea, but that's another story (https://github.com/scikit-learn/scikit-learn/pull/13209)…
Reports of inconsistent results with probability=True (which, if true, means something is likely wrong in libsvm) are an orthogonal issue and should be addressed in https://github.com/scikit-learn/scikit-learn/issues/13662. This issue is purely about API consistency: the decoupling of probability=True|False and the consistency of predict and predict_proba.
To reiterate: the fact that predict and predict_proba can be inconsistent is, IMO, a bug that breaks API expectations, and no amount of documentation is sufficient to fix it. We have a common test for this, but it passes because the issue only happens occasionally.
So the choices (adapted from the initial issue description) could be:
a) Deprecate probability=True and predict_proba, then suggest using CalibratedClassifierCV + SVC. This is bound to make users who currently rely on this option unhappy.
b) When probability=True, compute predict as the argmax of predict_proba. This could have been the solution, except that it silently breaks backward compatibility. I guess a lot of users run SVC(probability=True) in production, so we can't make this change silently.
c) Deprecate probability=True and predict_proba, then add e.g. CalibratedSVC and CalibratedNuSVC classes (names ending with CV could have been better, but are harder to read) that behave like SVC(probability=True) and where predict is computed from predict_proba.
d) Deprecate probability=True and replace it with probability='calibrated', for which predict is computed from predict_proba.
I would probably vote for d), unless we are OK with introducing two new classes as in c).
Thoughts, @amueller?
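Options b) and d) both amount to deriving predict from predict_proba. A minimal sketch of that behaviour as a subclass, for illustration only: the class name ArgmaxProbaSVC is hypothetical and not part of scikit-learn, and this is not necessarily how option d) would be implemented in the library.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.svm import SVC


class ArgmaxProbaSVC(SVC):
    """Hypothetical SVC variant: with probability=True, predict is argmax of predict_proba."""

    def predict(self, X):
        if self.probability:
            proba = self.predict_proba(X)
            return self.classes_[np.argmax(proba, axis=1)]
        # Without probability estimates, keep libsvm's decision-based predict.
        return super().predict(X)


X, y = load_iris(return_X_y=True)
clf = ArgmaxProbaSVC(probability=True, random_state=0).fit(X, y)
consistent = clf.predict(X) == clf.classes_[np.argmax(clf.predict_proba(X), axis=1)]
print("all consistent:", bool(consistent.all()))
```

Applying this override to existing probability=True models would change their predictions, which is exactly the backward-compatibility concern raised for option b).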