import numpy as np
def load_data(file_name):
''' 数据导入函数
input: file-name(string)的训练数据的位置
output: feature_data(mat)特征
label_data(mat)标签
'''
returnMat=np.loadtxt(file_name,delimiter='\t')
returnMat=np.mat(returnMat)
label_data=returnMat[:,-1]
feature_data=np.delete(returnMat,-1,axis=1)
a=np.ones((feature_data.shape[0],1))
feature_data=np.c_[feature_data,a]
return feature_data,label_data
3.定义各种要使用的函数计算
def sig(x):
''' sigmoid函数
input: X(mat) :feature_data*w
output: sigmoid(x)(mat):sigmoid的值
'''
sigmoidx=1/(1+np.exp(-x))
sigmoidx=np.mat(sigmoidx)
return sigmoidx
def lr_train_bgd(feature_data,label_data,maxCycle,alpha):
''' 利用梯度下降法训练LR模型
input: feature_data(mat)样例数据
label_data(mat)标签数据
maxCyxle(int)最大迭代次数
alpha(float)学习率
'''
s=feature_data.shape[1]
sta=np.mat(np.ones((s,1)))
i=0
while i
def model(w,feature_data):
'''模型计算预测值函数
input:w(mat)模型权重
feature_data(mat) 训练数据特征
output:h(mat) 预测值
'''
x=feature_data*w
h=sig(x)
g=h.shape[0]
for i in range(0,g):
if h[i]>0.5:
h[i]=1
elif h[i]<0.5:
h[i]=0
return h
def error_rate(h,label_data):
''' 计算当前的损失函数正确率
input: h(mat)预测值
label_data 实际值
output: error(float)错误率
'''
thr=0
numbel=label_data.shape[0]
for i in range(0,numbel):
if h[i]==label_data[i]:
thr=thr+1
error=thr/numbel
return error,numbel
if __name__ == "__main__":
# 1.导入训练数据
print("-----1.load data-----")
feature_data,label_data=load_data("horseColicTraining.txt")
text_data,text_label_data=load_data("horseColicTest.txt")
#2.训练LR模型
print("-----2.training------")
w= lr_train_bgd(feature_data,label_data,10000,0.000001)
#3.保存最终模型
# print("-----3.save model----")
#save_model("weights",w)
#3.构建最终模型,输出模型预测值
print("------3.model---------")
h=model(w, text_data)
#4.输出误差率
print("-------4.error value-----")
error,numbe=error_rate(h,text_label_data)
print(error)
print(text_label_data)
print(h)
print(w)
#print(feature_data.shape[1])
#print(label_data)
#5.预测数值
# print("--------5.furture value-----")
# h1=model(w,text_data)
后记:这是根据周志华西瓜书上的公式进行推导后,进行的代码实现,是对数几率的简单实现,没有加入正则化等内容,对于初学者来说会简单一点,适合入门学习,数据集我会在后边给出。要是需要对数几率详细的数学推导,可在留言处评论,希望大家共同学习。
数据集下载地址:二分类数据集下载地址,点这里下载