加载所需数据与所需的python库。
import statsmodels.api as sm
import statsmodels.formula.api as smf
import statsmodels.graphics.api as smg
import patsy
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pandas import Series,DataFrame
from scipy import stats
import seaborn as sns
train = pd.read_csv("D:/学习/数据挖掘与机器学习/Titanic/train.csv")
数据集中共有12个字段,PassengerId:乘客编号,Survived:乘客是否存活,Pclass:乘客所在的船舱等级;Name:乘客姓名,Sex:乘客性别,Age:乘客年龄,SibSp:乘客的兄弟姐妹和配偶数量,Parch:乘客的父母与子女数量,Ticket:票的编号,Fare:票价,Cabin:座位号,Embarked:乘客登船码头。 共有891位乘客的数据信息。其中277位乘客的年龄数据缺失,2位乘客的登船码头数据缺失,687位乘客的船舱数据缺失。
train.head()
PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th… | female | 38 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35 | 1 | 0 | 113803 | 53.1000 | C123 | S |
4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35 | 0 | 0 | 373450 | 8.0500 | NaN | S |
train.info()
train.describe()
PassengerId | Survived | Pclass | Age | SibSp | Parch | Fare | |
---|---|---|---|---|---|---|---|
count | 891.000000 | 891.000000 | 891.000000 | 714.000000 | 891.000000 | 891.000000 | 891.000000 |
mean | 446.000000 | 0.383838 | 2.308642 | 29.699118 | 0.523008 | 0.381594 | 32.204208 |
std | 257.353842 | 0.486592 | 0.836071 | 14.526497 | 1.102743 | 0.806057 | 49.693429 |
min | 1.000000 | 0.000000 | 1.000000 | 0.420000 | 0.000000 | 0.000000 | 0.000000 |
25% | 223.500000 | 0.000000 | 2.000000 | 20.125000 | 0.000000 | 0.000000 | 7.910400 |
50% | 446.000000 | 0.000000 | 3.000000 | 28.000000 | 0.000000 | 0.000000 | 14.454200 |
75% | 668.500000 | 1.000000 | 3.000000 | 38.000000 | 1.000000 | 0.000000 | 31.000000 |
max | 891.000000 | 1.000000 | 3.000000 | 80.000000 | 8.000000 | 6.000000 | 512.329200 |
画出训练集中乘客年龄和费用的分布直方图,如下所示。可以发现,大部分乘客的年龄位于20-40岁之间,总体上呈正态分布。大部分乘客的票价很低,位于0-100之间,其他少部分乘客的票价较高。
fig,ax = plt.subplots(nrows=1,ncols=2,figsize=(15,5))
train["Age"].hist(ax=ax[0])
ax[0].set_title("Hist plot of Age")
train["Fare"].hist(ax=ax[1])
ax[1].set_title("Hist plot of Fare")
<matplotlib.text.Text at 0x7672e4e588>
画出乘客获救与没有获救的条形图,如下所示。可以发现,大部分乘客没有获救。
fig,ax = plt.subplots(figsize=(7,5))
train["Survived"].value_counts().plot(kind="bar")
ax.set_xticklabels(("Not Survived","Survived"), rotation= "horizontal" )
ax.set_title("Bar plot of Survived ")
<matplotlib.text.Text at 0x7673102940>
画出乘客性别条形分布图,如下所示。可以发现,大部分乘客为男性。
fig,ax = plt.subplots(figsize=(7,5))
train["Sex"].value_counts().plot(kind="bar")
ax.set_xticklabels(("male","female"),rotation= "horizontal" )
ax.set_title("Bar plot of Sex ")
<matplotlib.text.Text at 0x767307af60>
画出乘客的Pclass条形分布图,如下所示。可以发现,大部分乘客位于第三等级,第一等级和第二等级的乘客各有200个左右。
fig,ax = plt.subplots(figsize=(7,5))
train["Pclass"].value_counts().plot(kind="bar")
ax.set_xticklabels(("Class3","Class1","Class2"),rotation= "horizontal" )
ax.set_title("Bar plot of Pclass ")
<matplotlib.text.Text at 0x76731587b8>
对乘客座位号数据进行处理,将缺失值赋值为Unknown。从乘客座位号数据可以发现,第一个字母可能代表了船舱号码,将该字符提取出来,赋值给Cabin,视为船舱号。
train.Cabin.fillna("Unknown",inplace=True)
for i in range(0, 891):
train.Cabin[i]= train.Cabin[i][0]
D:\software\新建文件夹 (4)\lib\site-packages\ipykernel\__main__.py:3: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy app.launch_new_instance() 画出乘客的船舱号的条形分布图,如下所示。可以发现,大部分乘客的船舱号为未知。
fig,ax = plt.subplots(figsize=(7,5))
train.Cabin.value_counts().plot(kind="bar")
ax.set_title("Bar plot of Cabin ")
<matplotlib.text.Text at 0x76731bd908>
画出乘客兄弟姐妹与配偶数目的条形分布图,如下所示。可以发现,大部分乘客在船上没有兄弟姐妹或配偶,大约200位乘客在船上有1个兄弟姐妹或配偶。
fig,ax = plt.subplots(figsize=(7,5))
train["SibSp"].value_counts().plot(kind="bar")
ax.set_title("Bar plot of SibSp ")
<matplotlib.text.Text at 0x76713d7438>
画出乘客父母与子女数目的条形分布图,如下所示。可以发现,大部分乘客在船上没有父母或子女,100多位乘客在船上有1个兄弟姐妹或配偶,大约90位乘客在船上有2个兄弟姐妹或配偶。
fig,ax = plt.subplots(figsize=(7,5))
train["Parch"].value_counts().plot(kind="bar")
ax.set_title("Bar plot of Parch ")
<matplotlib.text.Text at 0x7671457ef0>
画出乘客出发港口的分布条形图,如下所示。可以发现,大部分乘客从Southampton港口出发,不到200位乘客从Cherburge出发,不到100位乘客从Queentown出发。
fig,ax = plt.subplots(figsize=(7,5))
train["Embarked"].value_counts().plot(kind="bar")
ax.set_xticklabels(("Southampton","Cherbourg","Queenstown"),rotation= "horizontal" )
ax.set_title("Bar plot of Embarked ")
<matplotlib.text.Text at 0x76714d9ac8>
画出性别与是否获救的交叉表和条形图,如下所示。可以发现,女性获救的可能性更高,而男性获救的比例很低。
pd.crosstab(train["Sex"],train["Survived"])
Survived | 0 | 1 |
---|---|---|
Sex | ||
female | 81 | 233 |
male | 468 | 109 |
pd.crosstab(train["Sex"],train["Survived"]).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x76714d0d68>
画出船舱等级与是否获救的交叉表与条形图,如下所示。可以发现,第一等级的乘客获救的可能性更高,超过50%,第二等级的乘客获救可能性在50%左右,而第三等级的乘客获救可能性很低。
pd.crosstab(train["Pclass"],train["Survived"])
Survived | 0 | 1 |
---|---|---|
Pclass | ||
1 | 80 | 136 |
2 | 97 | 87 |
3 | 372 | 119 |
pd.crosstab(train["Pclass"],train["Survived"]).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x76741cbf98>
画出兄弟姐妹与配偶数目与是否获救的交叉表与条形图,如下所示。可以发现,有数量为1或2的乘客获救的可能性更高。
pd.crosstab(train["SibSp"],train["Survived"])
Survived | 0 | 1 |
---|---|---|
SibSp | ||
0 | 398 | 210 |
1 | 97 | 112 |
2 | 15 | 13 |
3 | 12 | 4 |
4 | 15 | 3 |
5 | 5 | 0 |
8 | 7 | 0 |
pd.crosstab(train["SibSp"],train["Survived"]).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x76714a9d68>
画出父母或子女数目与是否获救的交叉表与条形图,如下所示。可以发现,有数量为1或2的乘客获救的可能性更高。
pd.crosstab(train["Parch"],train["Survived"])
Survived | 0 | 1 |
---|---|---|
Parch | ||
0 | 445 | 233 |
1 | 53 | 65 |
2 | 40 | 40 |
3 | 2 | 3 |
4 | 4 | 0 |
5 | 4 | 1 |
6 | 1 | 0 |
pd.crosstab(train["Parch"],train["Survived"]).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x767435b7b8>
画出登船港口与是否获救的交叉表与条形图,如下所示。可以发现,从Cherburge出发的乘客获救的人数比例更高。
pd.crosstab(train["Embarked"],train["Survived"])
Survived | 0 | 1 |
---|---|---|
Embarked | ||
C | 75 | 93 |
Q | 47 | 30 |
S | 427 | 217 |
pd.crosstab(train["Embarked"],train["Survived"]).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x767436f4e0>
画出乘客所在船舱与是否获救的交叉表与条形图,如下所示。可以发现,船舱后没有缺失的乘客获救的人数比例更高。
pd.crosstab(train["Cabin"],train["Survived"])
Survived | 0 | 1 |
---|---|---|
Cabin | ||
A | 8 | 7 |
B | 12 | 35 |
C | 24 | 35 |
D | 8 | 25 |
E | 8 | 24 |
F | 5 | 8 |
G | 2 | 2 |
T | 1 | 0 |
U | 481 | 206 |
pd.crosstab(train["Cabin"],train["Survived"]).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x767441f6d8>
画出乘客是否获救与年龄的箱线图,如下所示。从箱线图上来看,两者关系并不明显。
fig,ay = plt.subplots()
Age1 = train.Age[train.Survived == 1].dropna()
Age0 = train.Age[train.Survived == 0].dropna()
plt.boxplot((Age1,Age0),labels=('Survived','Not Survived'))
ay.set_ylim([-5,70])
ay.set_title("Boxplot of Age")
<matplotlib.text.Text at 0x7674557a20>
画出乘客是否获救与票价的箱线图,如下所示。可以发现,总体而言,获救的乘客票价更高。
fig,ay = plt.subplots()
Fare1 = train.Fare[train.Survived == 1]
Fare0 = train.Fare[train.Survived == 0]
plt.boxplot((Fare1,Fare0),labels=('Survived','Not Survived'))
ay.set_ylim([-10,150])
ay.set_title("Boxplot of Fare")
<matplotlib.text.Text at 0x76745e5470>
画出乘客票价与舱位等级的箱线图,如下所示。可以明显的发现,舱位等级越高的乘客,票价越高。这两个变量之间存在非常明显的线性相关关系。
fig,ay = plt.subplots()
Farec1 = train.Fare[train.Pclass == 1]
Farec2 = train.Fare[train.Pclass == 2]
Farec3 = train.Fare[train.Pclass == 3]
plt.boxplot((Farec1,Farec2,Farec3),labels=("Pclass1","Pclass2","Pclass3"))
ay.set_ylim([-10,180])
ay.set_title("Boxplot of Fare and Pclass")
<matplotlib.text.Text at 0x767466e8d0>
用年龄的均值填充年龄的缺失值,用出发港口的众数填补出发港口的缺失值。
train.Age.mean()
train.Age.fillna(29.7,inplace=True)
train.Embarked.fillna("S",inplace=True)
根据以上分析结果和变量间的关系,将年龄数据分段为0-5岁、5-15岁、15-20岁、20-35岁、35-50岁、50-60岁、60-100岁7段。将Parch变量分成数目为0、数目为1或2、数目为大于2三段。将SibSp变量分成数目为0、数目为1或2、数目为大于2三段。将Cabin变量分为缺失和没有缺失两段。
train.age=pd.cut(train.Age,[0,5,15,20,35,50,60,100])
pd.crosstab(train.age,train.Survived).plot(kind="bar")
<matplotlib.axes._subplots.AxesSubplot at 0x7674657e48>
train.Parch[(train.Parch>0) & (train.Parch<=2)]=1
train.Parch[train.Parch>2]=2
train.SibSp[(train.SibSp>0) & (train.SibSp<=2)]=1
train.SibSp[train.SibSp>2]=2
train.Cabin[train.Cabin!="U"]="K"
D:\software\新建文件夹 (4)\lib\site-packages\ipykernel\__main__.py:1: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy if __name__ == ‘__main__’: D:\software\新建文件夹 (4)\lib\site-packages\ipykernel\__main__.py:2: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy from ipykernel import kernelapp as app D:\software\新建文件夹 (4)\lib\site-packages\ipykernel\__main__.py:3: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy app.launch_new_instance() D:\software\新建文件夹 (4)\lib\site-packages\ipykernel\__main__.py:4: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy D:\software\新建文件夹 (4)\lib\site-packages\ipykernel\__main__.py:5: SettingWithCopyWarning: A value is trying to be set on a copy of a slice from a DataFrame See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
为Pclass、Sex、Embarked、Parch、SibSp、Cabin变量创建虚拟变量
dummy_Pclass = pd.get_dummies(train.Pclass, prefix='Pclass')
dummy_Sex = pd.get_dummies(train.Sex, prefix='Sex')
dummy_Embarked = pd.get_dummies(train.Embarked, prefix='Embarked')
dummy_Parch = pd.get_dummies(train.Parch, prefix='Parch')
dummy_SibSp = pd.get_dummies(train.SibSp, prefix='SibSp')
dummy_Age = pd.get_dummies(train.age, prefix='Age')
dummy_Cabin = pd.get_dummies(train.Cabin, prefix='Cabin')
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, roc_curve,roc_auc_score,classification_report
划分训练集,将编号为0-623的乘客作为训练集。去除PassengerId和Name变量,添加常数项intercept. 因变量为乘客是否获救,自变量为乘客的票价、性别、登船码头、父母与子女数目、兄弟姐妹与配偶数目、年龄、船舱。除票价外,都为虚拟变量。考虑到Fare和Pclass之间的线性相关性,剔除Pclass变量。
train_y = train[:623]["Survived"]
cols_to_keep = ["Fare"]
train_x = train[:623][cols_to_keep].join(dummy_Sex.ix[:, "Sex_male":]).join(dummy_Embarked.ix[:,"Embarked_Q":]).join(dummy_Parch.ix[:,"Parch_1":]).join(dummy_SibSp.ix[:,"SibSp_1":]).join(dummy_Age.ix[:,"Age_(5, 15]":]).join(dummy_Cabin.ix[:,"Cabin_U" :])
train_x['intercept'] = 1.0
train_x.tail()
Fare | Sex_male | Embarked_Q | Embarked_S | Parch_1 | Parch_2 | SibSp_1 | SibSp_2 | Age_(5, 15] | Age_(15, 20] | Age_(20, 35] | Age_(35, 50] | Age_(50, 60] | Age_(60, 100] | Cabin_U | intercept | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
618 | 39.0000 | 0 | 0 | 1 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
619 | 10.5000 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 |
620 | 14.4542 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 |
621 | 52.5542 | 1 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 |
622 | 15.7417 | 1 | 0 | 0 | 1 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
对训练集构建逻辑斯蒂模型。
clf = LogisticRegression()
clf.fit(train_x,train_y)
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=100, multi_class=’ovr’, n_jobs=1, penalty=’l2’, random_state=None, solver=’liblinear’, tol=0.0001, verbose=0, warm_start=False)
划分测试集,将编号为624-890的乘客作为测试集。
test_y = train[623:]["Survived"]
cols_to_keep = ["Fare"]
test_x = train[623:][cols_to_keep].join(dummy_Sex.ix[:, "Sex_male":]).join(dummy_Embarked.ix[:,"Embarked_Q":]).join(dummy_Parch.ix[:,"Parch_1":]).join(dummy_SibSp.ix[:,"SibSp_1":]).join(dummy_Age.ix[:,"Age_(5, 15]":]).join(dummy_Cabin.ix[:,"Cabin_U" :])
test_x['intercept'] = 1.0
test_x.head()
Fare | Sex_male | Embarked_Q | Embarked_S | Parch_1 | Parch_2 | SibSp_1 | SibSp_2 | Age_(5, 15] | Age_(15, 20] | Age_(20, 35] | Age_(35, 50] | Age_(50, 60] | Age_(60, 100] | Cabin_U | intercept | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
623 | 7.8542 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 |
624 | 16.1000 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 1 |
625 | 32.3208 | 1 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 |
626 | 12.3500 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 1 |
627 | 77.9583 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 |
利用测试集对模型进行测试
clf.predict(test_x)
array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1,
0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0,
0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0,
0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0,
1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1,
0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0], dtype=int64)
clf.predict_proba(test_x)
array([[ 0.86039834, 0.13960166],
[ 0.85711962, 0.14288038],
[ 0.74830885, 0.25169115],
[ 0.84502153, 0.15497847],
[ 0.12886629, 0.87113371],
[ 0.86038196, 0.13961804],
[ 0.81492416, 0.18507584],
[ 0.74973913, 0.25026087],
[ 0.88595841, 0.11404159],
[ 0.60609598, 0.39390402],
[ 0.86346251, 0.13653749],
[ 0.61197924, 0.38802076],
[ 0.35569986, 0.64430014],
[ 0.86037046, 0.13962954],
[ 0.79463014, 0.20536986],
[ 0.57570079, 0.42429921],
[ 0.83467975, 0.16532025],
[ 0.85798051, 0.14201949],
[ 0.10986907, 0.89013093],
[ 0.56445448, 0.43554552],
[ 0.84012224, 0.15987776],
[ 0.15820833, 0.84179167],
[ 0.56790934, 0.43209066],
[ 0.85796389, 0.14203611],
[ 0.65552637, 0.34447363],
[ 0.86051808, 0.13948192],
[ 0.35980506, 0.64019494],
[ 0.86038196, 0.13961804],
[ 0.29324608, 0.70675392],
[ 0.86017015, 0.13982985],
[ 0.28622388, 0.71377612],
[ 0.28287555, 0.71712445],
[ 0.8070549 , 0.1929451 ],
[ 0.86038196, 0.13961804],
[ 0.20682467, 0.79317533],
[ 0.85835972, 0.14164028],
[ 0.53882734, 0.46117266],
[ 0.80220769, 0.19779231],
[ 0.85539757, 0.14460243],
[ 0.69483827, 0.30516173],
[ 0.87933359, 0.12066641],
[ 0.83561804, 0.16438196],
[ 0.8070549 , 0.1929451 ],
[ 0.85835972, 0.14164028],
[ 0.86042952, 0.13957048],
[ 0.87914067, 0.12085933],
[ 0.11937883, 0.88062117],
[ 0.2853273 , 0.7146727 ],
[ 0.59808311, 0.40191689],
[ 0.90594814, 0.09405186],
[ 0.85835972, 0.14164028],
[ 0.86346251, 0.13653749],
[ 0.85801214, 0.14198786],
[ 0.86032122, 0.13967878],
[ 0.35349566, 0.64650434],
[ 0.52724747, 0.47275253],
[ 0.22879217, 0.77120783],
[ 0.28601744, 0.71398256],
[ 0.56939306, 0.43060694],
[ 0.85743204, 0.14256796],
[ 0.94208727, 0.05791273],
[ 0.82348607, 0.17651393],
[ 0.74901495, 0.25098505],
[ 0.94336392, 0.05663608],
[ 0.85705258, 0.14294742],
[ 0.85800384, 0.14199616],
[ 0.05587819, 0.94412181],
[ 0.5941366 , 0.4058634 ],
[ 0.18541791, 0.81458209],
[ 0.84012224, 0.15987776],
[ 0.83358085, 0.16641915],
[ 0.87933975, 0.12066025],
[ 0.88380587, 0.11619413],
[ 0.87914067, 0.12085933],
[ 0.28628812, 0.71371188],
[ 0.48214083, 0.51785917],
[ 0.70716252, 0.29283748],
[ 0.05715212, 0.94284788],
[ 0.65795297, 0.34204703],
[ 0.25709963, 0.74290037],
[ 0.81492001, 0.18507999],
[ 0.83837631, 0.16162369],
[ 0.8727473 , 0.1272527 ],
[ 0.3942793 , 0.6057207 ],
[ 0.69435145, 0.30564855],
[ 0.25955289, 0.74044711],
[ 0.76489311, 0.23510689],
[ 0.11637845, 0.88362155],
[ 0.65775927, 0.34224073],
[ 0.63734093, 0.36265907],
[ 0.85975561, 0.14024439],
[ 0.8839741 , 0.1160259 ],
[ 0.66714496, 0.33285504],
[ 0.07984609, 0.92015391],
[ 0.15579323, 0.84420677],
[ 0.81105308, 0.18894692],
[ 0.86042952, 0.13957048],
[ 0.24260674, 0.75739326],
[ 0.8360098 , 0.1639902 ],
[ 0.85835972, 0.14164028],
[ 0.87740579, 0.12259421],
[ 0.59721595, 0.40278405],
[ 0.85765731, 0.14234269],
[ 0.72253208, 0.27746792],
[ 0.2862853 , 0.7137147 ],
[ 0.83015243, 0.16984757],
[ 0.3208549 , 0.6791451 ],
[ 0.08720221, 0.91279779],
[ 0.79040078, 0.20959922],
[ 0.86346251, 0.13653749],
[ 0.85835972, 0.14164028],
[ 0.85835972, 0.14164028],
[ 0.85711962, 0.14288038],
[ 0.53746947, 0.46253053],
[ 0.24073084, 0.75926916],
[ 0.86038196, 0.13961804],
[ 0.86038196, 0.13961804],
[ 0.65520865, 0.34479135],
[ 0.61675979, 0.38324021],
[ 0.04187545, 0.95812455],
[ 0.83467975, 0.16532025],
[ 0.86037046, 0.13962954],
[ 0.63589211, 0.36410789],
[ 0.7945787 , 0.2054213 ],
[ 0.35569986, 0.64430014],
[ 0.5923993 , 0.4076007 ],
[ 0.8149159 , 0.1850841 ],
[ 0.18626833, 0.81373167],
[ 0.55494501, 0.44505499],
[ 0.85974901, 0.14025099],
[ 0.86038196, 0.13961804],
[ 0.26826856, 0.73173144],
[ 0.72096103, 0.27903897],
[ 0.86042134, 0.13957866],
[ 0.85651789, 0.14348211],
[ 0.86032122, 0.13967878],
[ 0.12575524, 0.87424476],
[ 0.8577608 , 0.1422392 ],
[ 0.87946251, 0.12053749],
[ 0.83078796, 0.16921204],
[ 0.09214503, 0.90785497],
[ 0.85801214, 0.14198786],
[ 0.1353402 , 0.8646598 ],
[ 0.81833156, 0.18166844],
[ 0.28627693, 0.71372307],
[ 0.77835462, 0.22164538],
[ 0.86019807, 0.13980193],
[ 0.85974901, 0.14025099],
[ 0.87920886, 0.12079114],
[ 0.18831656, 0.81168344],
[ 0.83358085, 0.16641915],
[ 0.56217041, 0.43782959],
[ 0.85802213, 0.14197787],
[ 0.59346021, 0.40653979],
[ 0.26217247, 0.73782753],
[ 0.81492209, 0.18507791],
[ 0.0820547 , 0.9179453 ],
[ 0.26297266, 0.73702734],
[ 0.11560723, 0.88439277],
[ 0.65520865, 0.34479135],
[ 0.7961241 , 0.2038759 ],
[ 0.86071471, 0.13928529],
[ 0.86063609, 0.13936391],
[ 0.35525524, 0.64474476],
[ 0.92489299, 0.07510701],
[ 0.71693682, 0.28306318],
[ 0.6076947 , 0.3923053 ],
[ 0.8149159 , 0.1850841 ],
[ 0.85057636, 0.14942364],
[ 0.6376244 , 0.3623756 ],
[ 0.822631 , 0.177369 ],
[ 0.86038196, 0.13961804],
[ 0.87740579, 0.12259421],
[ 0.17163368, 0.82836632],
[ 0.35894969, 0.64105031],
[ 0.83357894, 0.16642106],
[ 0.26194935, 0.73805065],
[ 0.85835972, 0.14164028],
[ 0.26062053, 0.73937947],
[ 0.42452126, 0.57547874],
[ 0.71744256, 0.28255744],
[ 0.86074419, 0.13925581],
[ 0.86042952, 0.13957048],
[ 0.71232894, 0.28767106],
[ 0.35504562, 0.64495438],
[ 0.87740579, 0.12259421],
[ 0.11900024, 0.88099976],
[ 0.86038523, 0.13961477],
[ 0.87341934, 0.12658066],
[ 0.85935323, 0.14064677],
[ 0.60934863, 0.39065137],
[ 0.86032122, 0.13967878],
[ 0.67707567, 0.32292433],
[ 0.35952192, 0.64047808],
[ 0.751824 , 0.248176 ],
[ 0.8796969 , 0.1203031 ],
[ 0.94539356, 0.05460644],
[ 0.10542774, 0.89457226],
[ 0.86007975, 0.13992025],
[ 0.88191682, 0.11808318],
[ 0.12684285, 0.87315715],
[ 0.93191133, 0.06808867],
[ 0.81531115, 0.18468885],
[ 0.84012224, 0.15987776],
[ 0.69813211, 0.30186789],
[ 0.8149159 , 0.1850841 ],
[ 0.18808428, 0.81191572],
[ 0.22676529, 0.77323471],
[ 0.71814943, 0.28185057],
[ 0.83357894, 0.16642106],
[ 0.86039834, 0.13960166],
[ 0.85780233, 0.14219767],
[ 0.0849913 , 0.9150087 ],
[ 0.86007975, 0.13992025],
[ 0.86032122, 0.13967878],
[ 0.84012224, 0.15987776],
[ 0.60672195, 0.39327805],
[ 0.85795222, 0.14204778],
[ 0.85692031, 0.14307969],
[ 0.29681064, 0.70318936],
[ 0.83393869, 0.16606131],
[ 0.85765731, 0.14234269],
[ 0.87931473, 0.12068527],
[ 0.95077511, 0.04922489],
[ 0.83327556, 0.16672444],
[ 0.8180723 , 0.1819277 ],
[ 0.08871631, 0.91128369],
[ 0.93364059, 0.06635941],
[ 0.90670657, 0.09329343],
[ 0.18814843, 0.81185157],
[ 0.11532953, 0.88467047],
[ 0.3446263 , 0.6553737 ],
[ 0.30260559, 0.69739441],
[ 0.20902316, 0.79097684],
[ 0.70727922, 0.29272078],
[ 0.49908055, 0.50091945],
[ 0.83357894, 0.16642106],
[ 0.85717847, 0.14282153],
[ 0.83675021, 0.16324979],
[ 0.17163368, 0.82836632],
[ 0.6376244 , 0.3623756 ],
[ 0.85835972, 0.14164028],
[ 0.39467085, 0.60532915],
[ 0.2731416 , 0.7268584 ],
[ 0.63987502, 0.36012498],
[ 0.85974901, 0.14025099],
[ 0.72317603, 0.27682397],
[ 0.86038196, 0.13961804],
[ 0.11238481, 0.88761519],
[ 0.67348135, 0.32651865],
[ 0.87880937, 0.12119063],
[ 0.2665907 , 0.7334093 ],
[ 0.26297533, 0.73702467],
[ 0.85718307, 0.14281693],
[ 0.85796389, 0.14203611],
[ 0.86038196, 0.13961804],
[ 0.10513251, 0.89486749],
[ 0.29535416, 0.70464584],
[ 0.86038196, 0.13961804],
[ 0.35756781, 0.64243219],
[ 0.85935323, 0.14064677],
[ 0.86071471, 0.13928529],
[ 0.50077687, 0.49922313],
[ 0.85835972, 0.14164028],
[ 0.14507275, 0.85492725],
[ 0.26239326, 0.73760674],
[ 0.60648726, 0.39351274],
[ 0.8149159 , 0.1850841 ]])
preds = clf.predict(test_x)
计算模型的混淆矩阵如下所示。
confusion_matrix(test_y,preds)
array([[157, 15],
[ 35, 61]])
计算模型的ROC/AUC得分,并画出ROC曲线。模型的ROC/AUC得分为0.88,表明预测准确的概率为88%左右。模型预测结果较好。
pre = clf.predict_proba(test_x)
roc_auc_score(test_y,pre[:,1])
0.88114704457364346
fpr,tpr,thresholds = roc_curve(test_y,pre[:,1])
fig,ax = plt.subplots(figsize=(8,5))
plt.plot(fpr,tpr)
ax.set_title("Roc of Logistic Regression")
<matplotlib.text.Text at 0x7674a1f588>
模型预测结果分类报告如下所示。
print(classification_report(test_y,preds))
precision recall f1-score support
0 0.82 0.91 0.86 172
1 0.80 0.64 0.71 96
avg / total 0.81 0.81 0.81 268
总体而言,模型的拟合结果较好。