
When I use the sklearn function cross_val_predict() for a classification task, it returns the predicted labels of the input data. How can I get the predicted probabilities instead of the predicted classes?

For example, for the iris dataset, the output of cross_val_predict() with an XGBoost model is as follows:

>>> y_predict = cross_val_predict(clf, X, y, cv=5)
>>> y_predict
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 1, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

1 Answer


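The cross_val_predict() function has a parameter called method. If you set it to 'predict_proba', it returns the predicted probability of each class for every input sample instead of a single label. The estimator you pass in must implement predict_proba() for this to work.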
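A minimal sketch of the difference between the default method and method='predict_proba'. It uses scikit-learn's RandomForestClassifier as a stand-in for XGBoost (an assumption, so the snippet runs without xgboost installed); the n_estimators and random_state values are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)

# Stand-in for xgb.XGBClassifier(): any estimator with predict_proba() works.
clf = RandomForestClassifier(n_estimators=50, random_state=0)

# Default method='predict': one predicted label per sample.
labels = cross_val_predict(clf, X, y, cv=5)

# method='predict_proba': one row of class probabilities per sample.
proba = cross_val_predict(clf, X, y, cv=5, method='predict_proba')

print(labels.shape)  # (150,)
print(proba.shape)   # (150, 3), one column per iris class
```

Each row of proba sums to 1, because it is a probability distribution over the three classes.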

For example, if the data has two classes, 0 and 1, cross_val_predict() will give you the probability of both class 0 and class 1 for each input record, so the result has one column per class.
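A sketch of the two-class case, assuming scikit-learn's built-in breast cancer dataset (labels 0 and 1) and a scaled LogisticRegression as the classifier; both are illustrative choices, not part of the original answer. The columns follow the sorted class labels, so column 1 holds the probability of class 1.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)  # binary target: 0 or 1

# Scaling first helps LogisticRegression converge; the pipeline
# exposes predict_proba() from its final step.
clf = make_pipeline(StandardScaler(), LogisticRegression())

proba = cross_val_predict(clf, X, y, cv=5, method='predict_proba')

p_class0 = proba[:, 0]  # probability of class 0 for each record
p_class1 = proba[:, 1]  # probability of class 1 for each record

print(proba.shape)  # (569, 2)
```

In the binary case the two columns are complementary, so p_class1 alone is often enough (for example, as input to an ROC curve).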

Here is an example using the iris dataset and XGBoost:

>>> import xgboost as xgb
>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import cross_val_predict
>>> data = load_iris()
>>> X = data.data
>>> y = data.target
>>> clf = xgb.XGBClassifier()
>>> y_predict = cross_val_predict(clf, X, y, cv=5, method='predict_proba')
>>> y_predict
array([[0.99147004, 0.00644642, 0.00208359],
       [0.9914602 , 0.00644635, 0.0020934 ],
       [0.99147004, 0.00644642, 0.00208359],
       [0.99147004, 0.00644642, 0.00208359],
       [0.99147004, 0.00644642, 0.00208359],
       [0.99147004, 0.00644642, 0.00208359],
       [0.99147004, 0.00644642, 0.00208359],
       [0.99147004, 0.00644642, 0.00208359],
       [0.9913379 , 0.00644556, 0.00221657],
       [0.99147004, 0.00644642, 0.00208359],
       [0.9910517 , 0.0062491 , 0.00269931],
       [0.99268746, 0.00460876, 0.00270376],
       [0.992063  , 0.00460586, 0.00333117],
       ...
       [0.00298027, 0.00511934, 0.99190044],
       [0.00135648, 0.00233008, 0.9963134 ],
       [0.00218805, 0.01090564, 0.98690635],
       [0.00135674, 0.00214127, 0.99650204],
       [0.00283705, 0.00447759, 0.9926853 ],
       [0.00516285, 0.04157701, 0.9532601 ]], dtype=float32)
>>>
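If you need both the probabilities and the labels, a sketch like the one below recovers the labels from the probability array by taking the column with the highest probability. It again assumes RandomForestClassifier in place of XGBoost; for that estimator, predict() is defined as the argmax of predict_proba(), so with a fixed random_state and the same deterministic 5-fold split the two results match. Whether this holds exactly for other estimators (including XGBClassifier) is an assumption worth checking.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)
clf = RandomForestClassifier(n_estimators=50, random_state=0)

proba = cross_val_predict(clf, X, y, cv=5, method='predict_proba')
labels = cross_val_predict(clf, X, y, cv=5)

# Columns are ordered like the sorted class labels (0, 1, 2 for iris),
# so the index of the largest probability maps back to a class label.
classes = np.unique(y)
labels_from_proba = classes[np.argmax(proba, axis=1)]

# Probability of the predicted class, usable as a simple confidence score.
confidence = proba.max(axis=1)

print(np.array_equal(labels, labels_from_proba))  # True
```

Computing probabilities once and deriving labels from them avoids running the cross-validation twice.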


...