python单层感知机训练过程绘制(动态绘制)

以简单的单层感知机为例,展示动态绘制训练过程的方法。

代码如下:        

import numpy as np
import matplotlib.pyplot as plt


def Count_y(x, N, w, b):
    y_ = 0
    for i in range(N):  # 计算预测值
        y_ += w[i] * x[i]
    y_ += b  # 添加偏置
    return y_


def Generate_Dataset():  # 生成数据集,
    X = np.random.randn(50, 2)  # 生成50行2列的数据集
    Y = np.array([1 if 5 * i[0] - 3 * i[1] +2 >= 0 else -1 for i in X])  # 填写标签
    data_cole = ["g" if i > 0 else 'b' for i in Y]
    return X, Y, data_cole


def Train(X, Y, data_cole, N=2, lr=0.003, num_epochs=100):  # 样本,标签,学习率和循环轮数
    w = np.random.rand(N)  # 数据维数为2
    b = 0
    plt_x = []  # 画图用
    w_all = []
    b_all = []
    for i in range(N):
        plt_x.append([data[i] for data in X])
    x_max = [min(plt_x[0]), max(plt_x[0])]  #
    for j in range(num_epochs):  # 循环
        stop = True  # 停止标记,没有分类错误的就停止
        for i in range(Y.size):  # 一轮
            x = X[i]
            y = Y[i]
            y_ = 1 if Count_y(x, N, w, b)>=0 else -1  # 计算预测值
            if y_ * y <= 0:  # 标记错误
                w_all.append([w[0],w[1]])
                b_all.append(b)
                update = lr*(y-y_)
                w += update * x  # 更新权重
                b += update  # 更新偏置
                stop = False
        if stop:
            break
    w_all.append(w)
    b_all.append(b)
    Draw(plt_x, x_max, N, data_cole, w_all, b_all)  # 绘图
    # print(w,b)


def Draw(plt_x, x_max, N, data_cole, w_all, b_all):  # 绘图
    plt.figure()
    for i in range(np.size(b_all)):
        # 清除上次绘图
        plt.clf()
        # 设置显示范围
        plt.xlim(x_max[0], x_max[1])
        plt.ylim(x_max[0], x_max[1])
        w= w_all[i]
        b = b_all[i]
        print(w,b)
        plt.scatter(plt_x[0], plt_x[1], c=data_cole, alpha=0.8)  # 绘点散点
        # 绘制分类线
        y_ = [(-w[0] * i - b) / w[1] for i in x_max]
        plt.plot(x_max, y_, 'r')
        # 刷新图形
        plt.draw()
        # 等待0.05s
        plt.pause(0.05)
    plt.show()  # 遍历完成后不消失



X, Y, data_cole = Generate_Dataset()
Train(X, Y, data_cole, )

运行结果:

感知机训练过程展示

注意事项:

对w,b的保存注意不能直接w_all.append(w),注意拷贝影响。

你可能感兴趣的:(python,人工智能,算法)