Logistic Regression Implemented in Python

Table of Contents

    • Sigmoid function
    • Prediction function
    • Cost function
    • Loss function
    • Gradient descent
    • Code implementation (Python)
    • Dataset

Sigmoid function

$$g(z) = \frac{1}{1 + e^{-z}}$$

[Figure 1: plot of the sigmoid function]
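
As a quick illustration (not part of the article's code; the function and variable names here are chosen only for the example), a minimal NumPy sketch of the sigmoid function and its S-shaped curve could look like this:

import numpy as np
import matplotlib.pyplot as plt


def sigmoid(z):
    # Squashes any real number into the interval (0, 1)
    return 1 / (1 + np.exp(-z))


z = np.linspace(-10, 10, 200)
plt.plot(z, sigmoid(z))
plt.xlabel('z')
plt.ylabel('g(z)')
plt.title('sigmoid')
plt.show()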

Prediction function

$$h_{\theta}(x) = g(\theta^{T}x)$$
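
In other words, the hypothesis is the sigmoid applied to a linear combination of the features. A hedged sketch (the values of theta and x below are illustrative, not taken from the article):

import numpy as np


def sigmoid(z):
    return 1 / (1 + np.exp(-z))


def predict(theta, x):
    # h_theta(x) = g(theta^T x): the estimated probability that the label is 1
    return sigmoid(np.dot(theta, x))


theta = np.array([0.5, -1.0, 2.0])  # example parameter vector
x = np.array([1.0, 0.3, 0.7])       # feature vector with a leading 1 for the bias term
print(predict(theta, x))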

Cost function

$$\mathrm{cost}(h_{\theta}(x), y) =
\begin{cases}
-\log(h_{\theta}(x)), & \text{if } y = 1 \\
-\log(1 - h_{\theta}(x)), & \text{if } y = 0
\end{cases}$$

If the label is 1, the larger the prediction, the smaller the loss.

If the label is 0, the smaller the prediction, the smaller the loss.

These two cases correspond to the two branches above and can be merged into a single expression:

$$\mathrm{cost}(h_{\theta}(x), y) = -y\log(h_{\theta}(x)) - (1 - y)\log(1 - h_{\theta}(x))$$
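
A small sketch of this per-example cost (illustrative only; the names are not taken from the article's code):

import numpy as np


def cost(h, y):
    # -y*log(h) - (1-y)*log(1-h):
    # equals -log(h) when y == 1 and -log(1-h) when y == 0
    return -y * np.log(h) - (1 - y) * np.log(1 - h)


print(cost(0.9, 1))  # small loss: confident and correct
print(cost(0.9, 0))  # large loss: confident but wrong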

Loss function

$$J(\theta) = \frac{1}{m}\sum_{i=1}^{m} \mathrm{cost}(h_{\theta}(x^{i}), y^{i}) = -\frac{1}{m}\sum_{i=1}^{m}\left[ y^{i}\log(h_{\theta}(x^{i})) + (1 - y^{i})\log(1 - h_{\theta}(x^{i})) \right]$$

Since the per-example cost already carries the negative signs, the overall loss is simply its average over the m training examples.
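
In vectorized form this is just the mean cross-entropy; a NumPy sketch (assuming h and y are column vectors of length m, matching the conventions of the code below):

import numpy as np


def loss(h, y):
    # J(theta): mean cross-entropy over all m training examples
    m = len(y)
    return -np.sum(y * np.log(h) + (1 - y) * np.log(1 - h)) / m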

Gradient descent

$$\theta_{j} = \theta_{j} - \alpha \frac{\partial J(\theta)}{\partial \theta_{j}}, \quad \text{i.e.} \quad \theta_{j} = \theta_{j} - \frac{\alpha}{m}\sum_{i=1}^{m}\left( h_{\theta}(x^{i}) - y^{i} \right) x_{j}^{i}$$

Stepping along the negative gradient $-\partial J(\theta)/\partial \theta_{j}$ is what keeps the loss function decreasing, provided the learning rate $\alpha$ is small enough.
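
One such update in vectorized form, as used in the code below (a sketch under the same conventions: X is an m x n matrix with a leading column of ones, y is m x 1, W is n x 1):

import numpy as np


def gradient_step(X, y, W, alpha):
    # Predictions for every sample, then one step along the negative gradient
    h = 1 / (1 + np.exp(-X.dot(W)))   # h_theta(x) for all m samples at once
    dW = X.T.dot(h - y) / len(X)      # (1/m) * sum_i (h_i - y_i) * x_i
    return W - alpha * dW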

Code implementation (Python)

import csv

import matplotlib.pyplot as plt
import numpy as np


def loadDataset():
    # Read the tab-separated training data: two features and a 0/1 label per row.
    # A constant 1.0 is prepended to each sample so the bias term can be folded into W.
    data = []
    labels = []
    with open('logisticDataset.txt', 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            data.append([1.0, float(row[0]), float(row[1])])
            labels.append(int(row[2]))
    return data, labels


def plotBestFit(W):
    # Plot the training samples as a scatter plot, colored by class
    dataMat, labelMat = loadDataset()
    dataArr = np.array(dataMat)
    n = np.shape(dataArr)[0]
    xcord1 = []
    ycord1 = []
    xcord2 = []
    ycord2 = []
    for i in range(n):
        if int(labelMat[i]) == 1:
            xcord1.append(dataArr[i, 1])
            ycord1.append(dataArr[i, 2])
        else:
            xcord2.append(dataArr[i, 1])
            ycord2.append(dataArr[i, 2])
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(xcord1, ycord1, s=30, c='red', marker='s')
    ax.scatter(xcord2, ycord2, s=30, c='green')

    # Plot the decision boundary: the line where W[0] + W[1]*x1 + W[2]*x2 = 0
    x = np.arange(-3.0, 3.0, 0.1)
    y = (-W[0] - W[1] * x) / W[2]
    ax.plot(x, y)
    plt.show()


def plotloss(loss_list):
    # Plot the loss value recorded after each gradient descent iteration
    x = np.arange(len(loss_list))
    plt.plot(x, np.array(loss_list), label='loss')

    plt.xlabel('iteration')  # number of gradient descent updates
    plt.ylabel('loss')       # loss value
    plt.title('loss trend')  # how the loss changes as W is updated
    plt.legend()
    plt.show()


def main():
    # Read the training data from the txt file
    data, labels = loadDataset()
    # Convert the data to NumPy arrays for matrix computation
    # Feature matrix X (m x 3); the first column is the constant 1.0 bias feature
    X = np.array(data)
    # Label vector y, reshaped into a column vector (m x 1)
    y = np.array(labels).reshape(-1, 1)
    # Initialize the weight column vector W (3 x 1) with small random values
    W = 0.001 * np.random.randn(3, 1)
    # m is the number of training samples
    m = len(X)
    # Learning rate for gradient descent
    learn_rate = 0.03

    loss_list = []
    # Gradient descent: repeatedly update W to minimize the loss function
    for i in range(3000):
        # The hypothesis, loss and gradient are all computed with NumPy matrix operations
        # Hypothesis h_theta(x) = sigmoid(X W), computed for all samples at once
        g_x = np.dot(X, W)
        h_x = 1 / (1 + np.exp(-g_x))

        # Loss value of the cost function (mean cross-entropy)
        loss = np.log(h_x) * y + (1 - y) * np.log(1 - h_x)
        loss = -np.sum(loss) / m
        loss_list.append(loss)

        # Gradient descent update of the weights W
        dW = X.T.dot(h_x - y) / m
        W += -learn_rate * dW

    # Visualize the results with the trained W
    print('Optimal W:')
    print(W)
    print('Resulting decision boundary:')
    plotBestFit(W)
    print('Loss trend as W is updated:')
    plotloss(loss_list)

    # Classify a single test sample: compute its predicted probability of belonging to class 1
    test_x = np.array([1, -1.395634, 4.662541])
    test_y = 1 / (1 + np.exp(-np.dot(test_x, W)))
    print(test_y)


if __name__ == '__main__':
    main()
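
The script prints a probability for the test sample; to turn probabilities into hard 0/1 labels one would typically threshold at 0.5. A small follow-up sketch (not part of the original script; it reuses X, y and the trained W from main()):

def accuracy(X, y, W):
    # Predicted probability for every sample, then a hard 0/1 decision at 0.5
    probs = 1 / (1 + np.exp(-X.dot(W)))
    preds = (probs >= 0.5).astype(int)
    return np.mean(preds == y)

# For example, inside main() after the training loop one could call:
# print('training accuracy:', accuracy(X, y, W))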

[Figure 2: example run output (the learned weights W)]

[Figure 3: decision boundary plotted over the training data]
[Figure 4: loss curve over the gradient descent iterations]

Dataset

-0.017612   14.053064  0
-1.395634  4.662541   1
-0.752157  6.538620   0
-1.322371  7.152853   0
0.423363   11.054677  0
0.406704   7.067335   1
0.667394   12.741452  0
-2.460150  6.866805   1
0.569411   9.548755   0
-0.026632  10.427743  0
0.850433   6.920334   1
1.347183   13.175500  0
1.176813   3.167020   1
-1.781871  9.097953   0
-0.566606  5.749003   1
0.931635   1.589505   1
-0.024205  6.151823   1
-0.036453  2.690988   1
-0.196949  0.444165   1
1.014459   5.754399   1
1.985298   3.230619   1
-1.693453  -0.557540  1
-0.576525  11.778922  0
-0.346811  -1.678730  1
-2.124484  2.672471   1
1.217916   9.597015   0
-0.733928  9.098687   0
-3.642001  -1.618087  1
0.315985   3.523953   1
1.416614   9.619232   0
-0.386323  3.989286   1
0.556921   8.294984   1
1.224863   11.587360  0
-1.347803  -2.406051  1
1.196604   4.951851   1
0.275221   9.543647   0
0.470575   9.332488   0
-1.889567  9.542662   0
-1.527893  12.150579  0
-1.185247  11.309318  0
-0.445678  3.297303   1
1.042222   6.105155   1
-0.618787  10.320986  0
1.152083   0.548467   1
0.828534   2.676045   1
-1.237728  10.549033  0
-0.683565  -2.166125  1
0.229456   5.921938   1
-0.959885  11.555336  0
0.492911   10.993324  0
0.184992   8.721488   0
-0.355715  10.325976  0
-0.397822  8.058397   0
0.824839   13.730343  0
1.507278   5.027866   1
0.099671   6.835839   1
-0.344008  10.717485  0
1.785928   7.718645   1
-0.918801  11.560217  0
-0.364009  4.747300   1
-0.841722  4.119083   1
0.490426   1.960539   1
-0.007194  9.075792   0
0.356107   12.447863  0
0.342578   12.281162  0
-0.810823  -1.466018  1
2.530777   6.476801   1
1.296683   11.607559  0
0.475487   12.040035  0
-0.783277  11.009725  0
0.074798   11.023650  0
-1.337472  0.468339   1
-0.102781  13.763651  0
-0.147324  2.874846   1
0.518389   9.887035   0
1.015399   7.571882   0
-1.658086  -0.027255  1
1.319944   2.171228   1
2.056216   5.019981   1
-0.851633  4.375691   1
-1.510047  6.061992   0
-1.076637  -3.181888  1
1.821096   10.283990  0
3.010150   8.401766   1
-1.099458  1.688274   1
-0.834872  -1.733869  1
-0.846637  3.849075   1
1.400102   12.628781  0
1.752842   5.468166   1
0.078557   0.059736   1
0.089392   -0.715300  1
1.825662   12.693808  0
0.197445   9.744638   0
0.126117   0.922311   1
-0.679797  1.220530   1
0.677983   2.556666   1
0.761349   10.693862  0
-2.168791  0.143632   1
1.388610   9.341997   0
0.317029   14.739025  0
