Regression logistique

sur le dataset Iris

avec scikit-learn

In [10]:
import pandas as pd
from sklearn.linear_model import LogisticRegression 
df = pd.read_csv('../data/classification/iris.csv')
In [11]:
df.shape
df.describe()
Out[11]:
sepal_length sepal_width petal_length petal_width
count 150.000000 150.000000 150.000000 150.000000
mean 5.843333 3.054000 3.758667 1.198667
std 0.828066 0.433594 1.764420 0.763161
min 4.300000 2.000000 1.000000 0.100000
25% 5.100000 2.800000 1.600000 0.300000
50% 5.800000 3.000000 4.350000 1.300000
75% 6.400000 3.300000 5.100000 1.800000
max 7.900000 4.400000 6.900000 2.500000
In [12]:
clf = LogisticRegression()
In [13]:
clf
Out[13]:
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)
In [14]:
df.head()
Out[14]:
sepal_length sepal_width petal_length petal_width class
0 5.1 3.5 1.4 0.2 Iris-setosa
1 4.9 3.0 1.4 0.2 Iris-setosa
2 4.7 3.2 1.3 0.2 Iris-setosa
3 4.6 3.1 1.5 0.2 Iris-setosa
4 5.0 3.6 1.4 0.2 Iris-setosa
In [15]:
df.columns
Out[15]:
Index(['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'class'], dtype='object')
In [17]:
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]
X.head()
Out[17]:
sepal_length sepal_width petal_length petal_width
0 5.1 3.5 1.4 0.2
1 4.9 3.0 1.4 0.2
2 4.7 3.2 1.3 0.2
3 4.6 3.1 1.5 0.2
4 5.0 3.6 1.4 0.2
In [19]:
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
y = le.fit_transform(df['class'])

y
Out[19]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
In [22]:
clf.fit(X,y)
Out[22]:
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)
In [23]:
yhat = clf.predict(X)
In [35]:
yhat_proba = clf.predict_proba(X)
yhat_proba
Out[35]:
array([[  8.79681649e-01,   1.20307538e-01,   1.08131372e-05],
       [  7.99706325e-01,   2.00263292e-01,   3.03825365e-05],
       [  8.53796795e-01,   1.46177302e-01,   2.59031285e-05],
       [  8.25383127e-01,   1.74558937e-01,   5.79356669e-05],
       [  8.97323628e-01,   1.02665167e-01,   1.12050036e-05],
       [  9.26986574e-01,   7.30004562e-02,   1.29693872e-05],
       [  8.95064974e-01,   1.04895775e-01,   3.92506205e-05],
       [  8.61839691e-01,   1.38141399e-01,   1.89095833e-05],
       [  8.03156719e-01,   1.96758495e-01,   8.47861140e-05],
       [  7.95421554e-01,   2.04552763e-01,   2.56832240e-05],
       [  8.92083069e-01,   1.07910759e-01,   6.17176870e-06],
       [  8.63364991e-01,   1.36600589e-01,   3.44201798e-05],
       [  7.88177618e-01,   2.11794929e-01,   2.74526810e-05],
       [  8.35079702e-01,   1.64888155e-01,   3.21426418e-05],
       [  9.28349898e-01,   7.16491356e-02,   9.66254924e-07],
       [  9.64535656e-01,   3.54620850e-02,   2.25877936e-06],
       [  9.40906153e-01,   5.90890027e-02,   4.84421830e-06],
       [  8.91740161e-01,   1.08245661e-01,   1.41772124e-05],
       [  8.96525617e-01,   1.03467608e-01,   6.77567332e-06],
       [  9.23615524e-01,   7.63726510e-02,   1.18248373e-05],
       [  8.30668332e-01,   1.69316458e-01,   1.52093733e-05],
       [  9.21914602e-01,   7.80675598e-02,   1.78384021e-05],
       [  9.26584671e-01,   7.34068679e-02,   8.46162713e-06],
       [  8.67785629e-01,   1.32146178e-01,   6.81931916e-05],
       [  8.41271506e-01,   1.58655904e-01,   7.25903122e-05],
       [  7.77263282e-01,   2.22695181e-01,   4.15365716e-05],
       [  8.81389224e-01,   1.18568969e-01,   4.18075826e-05],
       [  8.69974782e-01,   1.30013638e-01,   1.15794893e-05],
       [  8.60034106e-01,   1.39955486e-01,   1.04082979e-05],
       [  8.32052869e-01,   1.67892968e-01,   5.41625519e-05],
       [  8.07811588e-01,   1.92136477e-01,   5.19350231e-05],
       [  8.72544939e-01,   1.27438925e-01,   1.61360155e-05],
       [  9.33948477e-01,   6.60477336e-02,   3.78900866e-06],
       [  9.46250501e-01,   5.37475145e-02,   1.98493064e-06],
       [  7.95421554e-01,   2.04552763e-01,   2.56832240e-05],
       [  8.47610513e-01,   1.52377535e-01,   1.19520539e-05],
       [  8.70019435e-01,   1.29976367e-01,   4.19728170e-06],
       [  7.95421554e-01,   2.04552763e-01,   2.56832240e-05],
       [  8.31024910e-01,   1.68917216e-01,   5.78737851e-05],
       [  8.57737250e-01,   1.42246900e-01,   1.58501104e-05],
       [  9.00222082e-01,   9.97646975e-02,   1.32206853e-05],
       [  6.90741687e-01,   3.09094698e-01,   1.63615590e-04],
       [  8.66068303e-01,   1.33887708e-01,   4.39884356e-05],
       [  9.16308833e-01,   8.36288777e-02,   6.22895883e-05],
       [  9.15519114e-01,   8.44392129e-02,   4.16734713e-05],
       [  8.20309627e-01,   1.79642381e-01,   4.79919885e-05],
       [  9.09855663e-01,   9.01327650e-02,   1.15724381e-05],
       [  8.51214451e-01,   1.48746052e-01,   3.94971199e-05],
       [  8.95519736e-01,   1.04472911e-01,   7.35323849e-06],
       [  8.51563342e-01,   1.48419676e-01,   1.69821772e-05],
       [  2.98900777e-02,   8.60393138e-01,   1.09716785e-01],
       [  3.74487166e-02,   7.05572459e-01,   2.56978825e-01],
       [  1.17957675e-02,   7.48252356e-01,   2.39951876e-01],
       [  1.32920493e-02,   6.51770445e-01,   3.34937506e-01],
       [  1.09868088e-02,   6.98832091e-01,   2.90181101e-01],
       [  1.07669519e-02,   5.83013186e-01,   4.06219862e-01],
       [  2.15200540e-02,   5.37732882e-01,   4.40747064e-01],
       [  1.08418544e-01,   7.68766189e-01,   1.22815267e-01],
       [  1.77270021e-02,   8.27562690e-01,   1.54710308e-01],
       [  3.30493839e-02,   5.28708770e-01,   4.38241846e-01],
       [  2.93117962e-02,   7.72717609e-01,   1.97970595e-01],
       [  4.09569813e-02,   6.19765980e-01,   3.39277039e-01],
       [  1.95378252e-02,   8.79697992e-01,   1.00764183e-01],
       [  8.73285529e-03,   5.96503817e-01,   3.94763328e-01],
       [  1.67434866e-01,   7.12756209e-01,   1.19808925e-01],
       [  4.75535678e-02,   8.43626581e-01,   1.08819852e-01],
       [  1.22530319e-02,   4.23869480e-01,   5.63877488e-01],
       [  3.84753639e-02,   8.50175432e-01,   1.11349204e-01],
       [  3.09968794e-03,   5.96264678e-01,   4.00635634e-01],
       [  3.59781700e-02,   8.08752206e-01,   1.55269624e-01],
       [  6.20745751e-03,   2.73106189e-01,   7.20686354e-01],
       [  5.81151228e-02,   8.19701311e-01,   1.22183566e-01],
       [  1.95840574e-03,   5.33800891e-01,   4.64240703e-01],
       [  8.77628703e-03,   7.04654010e-01,   2.86569703e-01],
       [  3.69274341e-02,   8.38990091e-01,   1.24082475e-01],
       [  3.61807169e-02,   8.28744840e-01,   1.35074443e-01],
       [  8.14489700e-03,   7.77156946e-01,   2.14698157e-01],
       [  4.64006697e-03,   5.23164549e-01,   4.72195384e-01],
       [  1.33500103e-02,   5.63205976e-01,   4.23444014e-01],
       [  1.28473017e-01,   8.31361691e-01,   4.01652917e-02],
       [  3.60902230e-02,   8.03217466e-01,   1.60692311e-01],
       [  5.05096042e-02,   8.46149445e-01,   1.03340951e-01],
       [  5.69724571e-02,   8.11250984e-01,   1.31776559e-01],
       [  1.22453086e-03,   3.99201919e-01,   5.99573550e-01],
       [  1.03123407e-02,   3.65034695e-01,   6.24652965e-01],
       [  4.17476538e-02,   4.77844283e-01,   4.80408063e-01],
       [  1.90525287e-02,   7.45629538e-01,   2.35317933e-01],
       [  7.05352060e-03,   7.56932682e-01,   2.36013798e-01],
       [  5.57541864e-02,   6.67410837e-01,   2.76834977e-01],
       [  2.10790319e-02,   6.62362244e-01,   3.16558724e-01],
       [  8.98003281e-03,   5.99716389e-01,   3.91303578e-01],
       [  1.52196906e-02,   6.32329159e-01,   3.52451150e-01],
       [  3.47695685e-02,   7.98625645e-01,   1.66604786e-01],
       [  9.15416570e-02,   7.95877151e-01,   1.12581192e-01],
       [  1.98418694e-02,   6.40871800e-01,   3.39286330e-01],
       [  4.81040905e-02,   7.31039981e-01,   2.20855929e-01],
       [  3.44565240e-02,   6.77463657e-01,   2.88079819e-01],
       [  3.38822929e-02,   7.96899915e-01,   1.69217792e-01],
       [  2.54574647e-01,   6.90791330e-01,   5.46340233e-02],
       [  3.63488963e-02,   7.04234211e-01,   2.59416893e-01],
       [  1.86036022e-04,   1.48760823e-01,   8.51053141e-01],
       [  8.09069371e-04,   2.94422745e-01,   7.04768186e-01],
       [  2.78126551e-04,   3.30535386e-01,   6.69186488e-01],
       [  4.56288643e-04,   3.38732197e-01,   6.60811514e-01],
       [  2.51393977e-04,   2.57092194e-01,   7.42656412e-01],
       [  6.03186905e-05,   3.82744333e-01,   6.17195349e-01],
       [  2.04838186e-03,   2.81103453e-01,   7.16848165e-01],
       [  1.23247784e-04,   4.24393655e-01,   5.75483097e-01],
       [  1.59929758e-04,   4.23195996e-01,   5.76644074e-01],
       [  3.56390886e-04,   1.52542892e-01,   8.47100717e-01],
       [  2.99635433e-03,   2.78024684e-01,   7.18978962e-01],
       [  6.45242833e-04,   3.55681241e-01,   6.43673516e-01],
       [  6.81029987e-04,   2.98859721e-01,   7.00459249e-01],
       [  6.28418142e-04,   2.96807692e-01,   7.02563890e-01],
       [  6.10997845e-04,   1.74593604e-01,   8.24795398e-01],
       [  1.09757190e-03,   1.73257823e-01,   8.25644605e-01],
       [  7.99254871e-04,   3.48929847e-01,   6.50270898e-01],
       [  1.93443479e-04,   2.38473708e-01,   7.61332849e-01],
       [  1.30064976e-05,   4.20137191e-01,   5.79849802e-01],
       [  6.81548718e-04,   4.69975854e-01,   5.29342597e-01],
       [  5.04477452e-04,   2.25292722e-01,   7.74202801e-01],
       [  1.33913767e-03,   2.30143290e-01,   7.68517573e-01],
       [  3.82097113e-05,   4.28006955e-01,   5.71954836e-01],
       [  2.05299242e-03,   4.00421888e-01,   5.97525119e-01],
       [  6.77847072e-04,   2.37204010e-01,   7.62118143e-01],
       [  4.56383243e-04,   3.97527741e-01,   6.02015876e-01],
       [  3.19858866e-03,   3.83866887e-01,   6.12934525e-01],
       [  3.42364119e-03,   3.27541103e-01,   6.69035256e-01],
       [  3.00544917e-04,   2.98288662e-01,   7.01410793e-01],
       [  6.78376797e-04,   5.10705151e-01,   4.88616472e-01],
       [  1.61719140e-04,   4.27941843e-01,   5.71896438e-01],
       [  6.44775841e-04,   3.44845359e-01,   6.54509865e-01],
       [  2.75279882e-04,   2.78027400e-01,   7.21697320e-01],
       [  2.07731418e-03,   4.90652652e-01,   5.07270034e-01],
       [  3.54683506e-04,   4.42580814e-01,   5.57064503e-01],
       [  1.82017584e-04,   3.42008155e-01,   6.57809828e-01],
       [  6.30908753e-04,   1.28602511e-01,   8.70766580e-01],
       [  9.21940559e-04,   3.20888055e-01,   6.78190005e-01],
       [  4.29311663e-03,   3.18426266e-01,   6.77280618e-01],
       [  1.16680587e-03,   3.00989509e-01,   6.97843685e-01],
       [  4.46290865e-04,   2.02461924e-01,   7.97091785e-01],
       [  2.15227432e-03,   2.48822456e-01,   7.49025270e-01],
       [  8.09069371e-04,   2.94422745e-01,   7.04768186e-01],
       [  2.91162367e-04,   2.24919706e-01,   7.74789132e-01],
       [  4.50477099e-04,   1.53984748e-01,   8.45564775e-01],
       [  1.15724730e-03,   2.33616548e-01,   7.65226205e-01],
       [  9.19025197e-04,   3.79220387e-01,   6.19860588e-01],
       [  1.45811816e-03,   2.98379693e-01,   7.00162189e-01],
       [  1.09779827e-03,   1.31785617e-01,   8.67116585e-01],
       [  1.68397530e-03,   2.81057800e-01,   7.17258224e-01]])
In [34]:
clf.coef_
Out[34]:
array([[ 0.41498833,  1.46129739, -2.26214118, -1.0290951 ],
       [ 0.41663969, -1.60083319,  0.57765763, -1.38553843],
       [-1.70752515, -1.53426834,  2.47097168,  2.55538211]])
In [24]:
yhat
Out[24]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1,
       1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
In [27]:
from sklearn.metrics import confusion_matrix
confusion_matrix(y,yhat)
Out[27]:
array([[50,  0,  0],
       [ 0, 45,  5],
       [ 0,  1, 49]])
In [29]:
clf.score(X,y)
Out[29]:
0.95999999999999996
In [33]:
 
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-33-fb211fc48cfc> in <module>()
      1 from sklearn.metrics import roc_auc_score
----> 2 roc_auc_score(y, yhat)

~/anaconda3/lib/python3.6/site-packages/sklearn/metrics/ranking.py in roc_auc_score(y_true, y_score, average, sample_weight)
    275     return _average_binary_score(
    276         _binary_roc_auc_score, y_true, y_score, average,
--> 277         sample_weight=sample_weight)
    278 
    279 

~/anaconda3/lib/python3.6/site-packages/sklearn/metrics/base.py in _average_binary_score(binary_metric, y_true, y_score, average, sample_weight)
     70     y_type = type_of_target(y_true)
     71     if y_type not in ("binary", "multilabel-indicator"):
---> 72         raise ValueError("{0} format is not supported".format(y_type))
     73 
     74     if y_type == "binary":

ValueError: multiclass format is not supported
In [ ]: