感知机 @ Python

感知机(二分类问题) @ Python

  • M 是 误分类点的集合
  • 损失函数 :
    • minw,bL(w,b)=xiMyi(wxi+b)
  • 损失函数的梯度 :
    • wL(w,b)=xiMyixi
    • bL(w,b)=xiMyi
  • 采用随机梯度下降法, 随机选取一个误分类点对w, b进行更新
    • w=w(αyixi)
    • b=b(αyi)
# _*_ coding:utf-8 _*_
import numpy as np
import matplotlib.pyplot as plt


class Perceptron:
    def __init__(self, x, y=1):
        self.x = x
        self.y = y
        self.w = np.ones((self.x.shape[1], 1)) / 10.0  # 提取x的第2维度的大小, 生成一个n X 1 的0.1矩阵
        self.b = 0.0  # 偏置项
        self.a = 1  # 改变此处可以得到不同的平面

    def train(self):
        length = self.x.shape[0]
        while True:
            count = 0  # 记录误分类点的数目
            for i in range(length):
                y = np.dot(self.x[i], self.w) + self.b
                # 如果是误分类点, 0 恰好在平面上
                if y * self.y[i] <= 0:
                    self.w = self.w + (self.a * self.y[i] * self.x[i]).reshape(self.w.shape)
                    self.b = self.b + self.a * self.y[i]
                    count += 1
            if count == 0:
                return self.w, self.b


class ShowPicture:
    def __init__(self, x, y, w, b):
        self.b = b
        self.w = w
        plt.figure(1)
        plt.title('what the fuck', size=14)
        plt.xlabel('x-axis', size=14)
        plt.ylabel('y-axis', size=14)

        xData = np.linspace(0, 5, 100)  # 创建等差数组
        yData = self.expression(xData)
        plt.plot(xData, yData, color='r', label='y1 data')

        # 绘制散点图
        for i in range(x.shape[0]):
            if y[i] < 0:
                plt.scatter(x[i][0], x[i][1], marker='x', s=50)
            else:
                plt.scatter(x[i][0], x[i][1], s=50)
        plt.savefig('2d.png', dpi=75)

    def expression(self, x):
        y = (-self.b - self.w[0] * x) / self.w[1]
        return y

    def show(self):
        plt.show()


xArray = np.array([[3, 3], [4, 3], [1, 1]])
yArray = np.array([1, 1, -1])
# [[3 3]
# [4 3]
# [1 1]]
p = Perceptron(x=xArray, y=yArray)
w, b = p.train()
s = ShowPicture(x=xArray, y=yArray, w=w, b=b)
s.show()

感知机 @ Python_第1张图片

你可能感兴趣的:(统计学习方法)