LogisticRegression's predict_proba(X) returns (n,) instead of (n,2) for binary class

See original GitHub issue

Just like scikit-learn's predict_proba(X), dask-ml's documentation says it returns an "array-like, shape = [n_samples, n_classes]". However, for a binary classifier it returns a one-dimensional array of shape (n_samples,). E.g. run this:

import numpy as np
from sklearn.linear_model import LogisticRegression
from dask_ml.linear_model import LogisticRegression as DaskLogit

# Data
X = np.random.rand(200,20)
y = np.random.rand(200).round(0) # random binary labels

# Dask 
lrdask = DaskLogit()
lrdask.fit(X, y)
print("Logit Reg 'predict_proba' shape Dask\t", lrdask.predict_proba(X).shape)

# SKLearn
lrsklr = LogisticRegression()
lrsklr.fit(X, y)
print("Logit Reg 'predict_proba' shape SKLearn\t", lrsklr.predict_proba(X).shape)

Of course, no information is lost, but the mismatch breaks code that expects scikit-learn's output shape. For example, GridSearchCV's scorer fails with:

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\metrics\scorer.py in __call__(self, clf, X, y, sample_weight)
    184 
    185                 if y_type == "binary":
--> 186                     y_pred = y_pred[:, 1]
    187                 elif isinstance(y_pred, list):
    188                     y_pred = np.vstack([p[:, -1] for p in y_pred]).T

IndexError: too many indices for array
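The failure can be reproduced without fitting a model. The sketch below uses made-up probabilities (not real model output) and a hypothetical function that mimics the scorer's binary branch, `y_pred[:, 1]`: indexing a second axis works on an (n, 2) array but raises IndexError on a one-dimensional (n,) array.

```python
import numpy as np

# Made-up probabilities standing in for predict_proba output (not real model output).
proba_1d = np.array([0.2, 0.7, 0.9])                   # dask-ml style: shape (n,)
proba_2d = np.column_stack([1 - proba_1d, proba_1d])   # scikit-learn style: shape (n, 2)

def positive_class_scores(y_pred):
    """Hypothetical stand-in for the scorer's binary branch: take column 1."""
    return y_pred[:, 1]

print(positive_class_scores(proba_2d))  # works: one score per sample

try:
    positive_class_scores(proba_1d)     # a 1-D array has no second axis
except IndexError as exc:
    print("IndexError:", exc)
```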

The return value of predict_proba(X) is indeed computed differently in scikit-learn and dask-ml. scikit-learn:

def _predict_proba_lr(self, X):
	...
	return np.vstack([1 - prob, prob]).T

dask-ml:

def predict_proba(self, X):
	...
	return sigmoid(dot(X_, self._coef))
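The two return statements can be mirrored in plain NumPy to compare their shapes. This is a sketch, not either library's code: `coef` is a hypothetical random coefficient vector, the intercept is omitted, and NumPy's `@` stands in for dask-ml's `dot`.

```python
import numpy as np

def sigmoid(z):
    # Logistic function 1 / (1 + exp(-z)).
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
X = rng.random((200, 20))
coef = rng.standard_normal(20)  # hypothetical fitted coefficients, no intercept

# dask-ml style: only P(y=1) per sample -> shape (200,)
prob = sigmoid(X @ coef)

# scikit-learn style: stack [P(y=0), P(y=1)] as columns -> shape (200, 2)
proba = np.vstack([1 - prob, prob]).T

print(prob.shape, proba.shape)  # prints: (200,) (200, 2)
```

Both carry the same information; the (n, 2) form simply adds the complementary column so that each row sums to 1.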

Is this intentional, and would changing it break code that currently expects dask-ml to return (n,) rather than (n,2)? Or is this an oversight, and should it return (n,2) as documented?

Issue Analytics

  • State: open
  • Created 5 years ago
  • Reactions: 1
  • Comments: 7 (6 by maintainers)

Top GitHub Comments

TomAugspurger commented, Aug 26, 2019 (1 reaction)

Still open, if you’re interested in working on it.

IIUC, there are two issues:

  1. For binary logistic regression, properly reshape the (N,) output to (N, 2).
  2. Implement multinomial logistic regression.
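For the second item, multinomial logistic regression typically turns per-class scores into probabilities with a softmax, which yields one column per class. The NumPy sketch below illustrates that assumption; it is not dask-ml's API, and the weights are hypothetical random values. It also checks that the two-class case, with the first class's weights fixed at zero, reproduces scikit-learn's [1 - p, p] layout.

```python
import numpy as np

def softmax_proba(X, W, b):
    """Class probabilities of shape (n_samples, n_classes).

    X: (n_samples, n_features), W: (n_features, n_classes), b: (n_classes,).
    """
    scores = X @ W + b
    # Subtract the row-wise max for numerical stability; it cancels in the ratio.
    scores = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
X = rng.random((5, 3))
W = rng.standard_normal((3, 4))  # hypothetical weights for 4 classes
proba = softmax_proba(X, W, np.zeros(4))
print(proba.shape)  # prints: (5, 4)

# Two classes, class 0 scores fixed at 0: softmax reduces to [1 - sigmoid(z), sigmoid(z)].
w = rng.standard_normal(3)
p2 = softmax_proba(X, np.column_stack([np.zeros(3), w]), np.zeros(2))
sig = 1.0 / (1.0 + np.exp(-(X @ w)))
print(np.allclose(p2[:, 1], sig))  # prints: True
```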

The first sounds easier, so we might want to start with that. The relevant code is around https://github.com/dask/dask-ml/blob/master/dask_ml/linear_model/glm.py#L248
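Until the binary case is fixed, a caller can adapt the output itself. `to_two_column_proba` below is a hypothetical helper, not part of dask-ml or scikit-learn: it converts an (n,) vector of positive-class probabilities into the (n, 2) layout and passes through arrays that already have two columns. Because it calls `np.asarray`, a Dask array passed to it would be computed into memory first.

```python
import numpy as np

def to_two_column_proba(prob):
    """Hypothetical helper: map P(y=1) with shape (n,) to [P(y=0), P(y=1)] with shape (n, 2).

    Inputs that already have shape (n, 2) are returned unchanged.
    """
    prob = np.asarray(prob, dtype=float)
    if prob.ndim == 2 and prob.shape[1] == 2:
        return prob
    if prob.ndim != 1:
        raise ValueError(f"expected shape (n,) or (n, 2), got {prob.shape}")
    # Column 0 is the negative class, column 1 the positive class (scikit-learn's order).
    return np.column_stack([1.0 - prob, prob])

print(to_two_column_proba([0.25, 0.5]))
```

In the reproduction above, `to_two_column_proba(lrdask.predict_proba(X))` should then have the same shape as `lrsklr.predict_proba(X)`. Accepting both shapes keeps such calling code working whether or not dask-ml changes its output.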

TomAugspurger commented, Mar 15, 2019 (1 reaction)

Nope, nothing beyond that: we should match scikit-learn.

