$$g(z)=\frac{1}{1+e^{-z}}$$

$$h_\theta(x)=g(\theta^T x)$$

$$\mathrm{cost}(h_\theta(x),y)=\begin{cases}-\log(h_\theta(x)), & \text{if } y=1 \\ -\log(1-h_\theta(x)), & \text{if } y=0\end{cases}$$
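Since the sigmoid squashes any real input into the interval $(0,1)$, the hypothesis $h_\theta(x)$ can be read as the predicted probability that $y=1$. A minimal numeric check (the probe values of $z$ are arbitrary):

```python
import numpy as np

def sigmoid(z):
    # Maps any real z into (0, 1); sigmoid(0) = 0.5
    return 1 / (1 + np.exp(-z))

print(sigmoid(np.array([-10.0, -1.0, 0.0, 1.0, 10.0])))
# Roughly: [4.5e-05, 0.269, 0.5, 0.731, 0.99995]
```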
If the label is 1, a larger prediction gives a smaller loss.
If the label is 0, a smaller prediction gives a smaller loss.
These correspond to the two cases of the cost function above.
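A quick sketch that evaluates both cases of the cost and confirms the behavior described above (the probe values of $h_\theta(x)$ are arbitrary):

```python
import numpy as np

def cost(h, y):
    # Piecewise logistic cost for a single example
    return -np.log(h) if y == 1 else -np.log(1 - h)

for h in (0.1, 0.5, 0.9):
    # y=1: cost shrinks as the prediction h grows toward 1
    # y=0: cost shrinks as the prediction h falls toward 0
    print(f"h={h}  cost(y=1)={cost(h, 1):.3f}  cost(y=0)={cost(h, 0):.3f}")
```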
The two cases combine into a single expression:

$$\mathrm{cost}(h_\theta(x),y)=-y\log(h_\theta(x))-(1-y)\log(1-h_\theta(x))$$

$$J(\theta)=\frac{1}{m}\sum_{i=1}^{m}\mathrm{cost}(h_\theta(x^{(i)}),y^{(i)})$$

Gradient descent updates each parameter along the negative gradient of $J$:

$$\theta_j := \theta_j - \alpha\frac{\partial J(\theta)}{\partial \theta_j} = \theta_j - \frac{\alpha}{m}\sum_{i=1}^{m}\left(h_\theta(x^{(i)})-y^{(i)}\right)x_j^{(i)}$$

Stepping along $-\partial J/\partial\theta_j$ keeps the loss decreasing, provided the learning rate $\alpha$ is small enough.
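One way to convince yourself of the closed-form gradient above is to compare it with a centered finite-difference approximation of $J(\theta)$. This is a standalone gradient check on synthetic data (the data and shapes here are made up for the check, not taken from the tutorial's dataset):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(20, 3))                        # 20 synthetic examples, 3 features
y = rng.integers(0, 2, size=(20, 1)).astype(float)  # random 0/1 labels
theta = rng.normal(size=(3, 1))

def J(theta):
    h = 1 / (1 + np.exp(-X @ theta))
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))

# Analytic gradient: (1/m) * X^T (h - y)
h = 1 / (1 + np.exp(-X @ theta))
grad = X.T @ (h - y) / len(X)

# Centered finite differences, one coordinate at a time
eps = 1e-6
for j in range(3):
    e = np.zeros((3, 1))
    e[j] = eps
    numeric = (J(theta + e) - J(theta - e)) / (2 * eps)
    print(f"theta_{j}: analytic={grad[j, 0]:.6f}  numeric={numeric:.6f}")
```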
import csv
import matplotlib.pyplot as plt
import numpy as np
def loadDataset():
    # Read the tab-separated training file; prepend a bias term of 1.0 to each example
    data = []
    labels = []
    with open('logisticDataset.txt', 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            data.append([1.0, float(row[0]), float(row[1])])
            labels.append(int(row[2]))
    return data, labels
def plotBestFit(W):
    # Scatter-plot the training data, colored by class
    dataMat, labelMat = loadDataset()
    dataArr = np.array(dataMat)
    n = np.shape(dataArr)[0]
    xcord1, ycord1 = [], []
    xcord2, ycord2 = [], []
    for i in range(n):
        if int(labelMat[i]) == 1:
            xcord1.append(dataArr[i, 1])
            ycord1.append(dataArr[i, 2])
        else:
            xcord2.append(dataArr[i, 1])
            ycord2.append(dataArr[i, 2])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
    ax.scatter(xcord2, ycord2, s=30, c='green')
    # Draw the decision boundary: W^T x = 0, solved for x2 as a function of x1
    x = np.arange(-3.0, 3.0, 0.1)
    y = (-W[0] - W[1] * x) / W[2]
    ax.plot(x, y)
    plt.show()
def plotloss(loss_list):
    # One point per gradient-descent iteration
    x = np.arange(len(loss_list))
    plt.plot(x, np.array(loss_list), label='loss')
    plt.xlabel('iteration')  # number of gradient-descent steps
    plt.ylabel('loss')       # loss value
    plt.title('loss trend')  # how the loss evolves as W is updated
    plt.legend()
    plt.show()
def main():
    # Read the training set from the txt file
    data, labels = loadDataset()
    # Convert to arrays so the computations below can be vectorized
    # Feature matrix X (m x 3, first column is the bias term)
    X = np.array(data)
    # Label vector y; reshape(-1, 1) turns it into a single column,
    # inferring the number of rows
    y = np.array(labels).reshape(-1, 1)
    # Initialize the weight vector W (3 x 1) with small random values
    W = 0.001 * np.random.randn(3, 1)
    # m is the number of training examples
    m = len(X)
    # Learning rate for gradient descent
    learn_rate = 0.03
    loss_list = []
    # Gradient descent: repeatedly update W to minimize the loss
    for i in range(3000):
        # The key steps are all vectorized NumPy matrix operations:
        # hypothesis, loss, and gradient.
        # Hypothesis h(x) = sigmoid(XW)
        g_x = np.dot(X, W)
        h_x = 1 / (1 + np.exp(-g_x))
        # Cross-entropy loss J(W)
        loss = y * np.log(h_x) + (1 - y) * np.log(1 - h_x)
        loss = -np.sum(loss) / m
        loss_list.append(loss)
        # Gradient step: dW = (1/m) X^T (h - y)
        dW = X.T.dot(h_x - y) / m
        W -= learn_rate * dW
    # Visualize the learned W
    print('Optimal W:')
    print(W)
    print('Final decision boundary:')
    plotBestFit(W)
    print('Loss trend as W is updated:')
    plotloss(loss_list)
    # Classify a test example: the output is P(y=1); predict class 1 if it exceeds 0.5
    test_x = np.array([1, -1.395634, 4.662541])
    test_y = 1 / (1 + np.exp(-np.dot(test_x, W)))
    print(test_y)

if __name__ == '__main__':
    main()
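Contents of `logisticDataset.txt` referenced above (tab-separated columns: $x_1$, $x_2$, label):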
-0.017612 14.053064 0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
1.196604 4.951851 1
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1
1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.317029 14.739025 0