BP经典入门算法实例—鸢尾花的分类(Python)

       Iris数据集(鸢尾花数据集下载,密码:ae1e)是常用的分类实验数据集,由Fisher,1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。
iris以鸢尾花的特征作为数据来源,常用在分类操作中。该数据集由3种不同类型的鸢尾花的50个样本数据构成。其中的一个种类与另外两个种类是线性可分离的,后两个种类是非线性可分离的。

该数据集包含了4个属性:
        Sepal.Length(花萼长度),单位是cm;
        Sepal.Width(花萼宽度),单位是cm;
        Petal.Length(花瓣长度),单位是cm;
        Petal.Width(花瓣宽度),单位是cm;
种类:Iris Setosa(1.山鸢尾)、Iris Versicolour(2.杂色鸢尾),以及Iris Virginica(3.维吉尼亚鸢尾)。

 

BP算法原理:

       BP算法实质是求取误差函数的最小值问题。这种算法采用非线性规划中的最速下降方法,按误差函数的负梯度方向修改权系数。 为了说明BP算法,首先定义误差函数e。取期望输出和实际输出之差的平方和为误差函数,则有:

                                                               

具体参考博客:https://blog.csdn.net/zhouchengyunew/article/details/6267193

源代码:

from __future__ import division #python3中不需要这句了
import math
import random
import pandas as pd
 
 
flowerLables = {0: 'Iris-setosa',
                1: 'Iris-versicolor',
                2: 'Iris-virginica'}
 
random.seed(0)
 
 
# 生成区间[a, b)内的随机数
def rand(a, b):
    return (b - a) * random.random() + a
  
# 生成大小 I*J 的矩阵,默认零矩阵
def makeMatrix(I, J, fill=0.0):
    m = []
    for i in range(I):
        m.append([fill] * J)
    return m
 
# 函数 sigmoid
def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))
  
# 函数 sigmoid 的导数
def dsigmoid(x):
    return x * (1 - x)

# 定义神经网络类  
class NN:
    """ 三层反向传播神经网络 """
 
    def __init__(self, ni, nh, no):

        # 输入层、隐藏层、输出层的节点(数)
        self.ni = ni + 1  # 增加一个偏差节点bias
        self.nh = nh + 1
        self.no = no
 
        # 激活神经网络的所有节点(向量)
        self.ai = [1.0] * self.ni
        self.ah = [1.0] * self.nh
        self.ao = [1.0] * self.no
 
        # 建立权重(矩阵)
        self.wi = makeMatrix(self.ni, self.nh)
        self.wo = makeMatrix(self.nh, self.no)

        # 设为随机值
        for i in range(self.ni):
            for j in range(self.nh):
                self.wi[i][j] = rand(-0.2, 0.2) #生成[-0.2,0.2]之间的随机数
        for j in range(self.nh):
            for k in range(self.no):
                self.wo[j][k] = rand(-2, 2)  ##生成[-2,2]之间的随机数
 
    def update(self, inputs):
        if len(inputs) != self.ni - 1:
            raise ValueError('与输入层节点数不符!')
 
        # 激活输入层
        for i in range(self.ni - 1):
            self.ai[i] = inputs[i]
 
        # 激活隐藏层
        for j in range(self.nh):
            sum = 0.0
            for i in range(self.ni):
                sum = sum + self.ai[i] * self.wi[i][j]
            self.ah[j] = sigmoid(sum)
 
        # 激活输出层
        for k in range(self.no):
            sum = 0.0
            for j in range(self.nh):
                sum = sum + self.ah[j] * self.wo[j][k]
            self.ao[k] = sigmoid(sum)
 
        return self.ao[:]
 
    def backPropagate(self, targets, lr):
        """ 反向传播 """
 
        # 计算输出层的误差
        output_deltas = [0.0] * self.no
        for k in range(self.no):
            error = targets[k] - self.ao[k]
            output_deltas[k] = dsigmoid(self.ao[k]) * error
 
        # 计算隐藏层的误差
        hidden_deltas = [0.0] * self.nh
        for j in range(self.nh):
            error = 0.0
            for k in range(self.no):
                error = error + output_deltas[k] * self.wo[j][k]
            hidden_deltas[j] = dsigmoid(self.ah[j]) * error
 
        # 更新输出层权重
        for j in range(self.nh):
            for k in range(self.no):
                change = output_deltas[k] * self.ah[j]
                self.wo[j][k] = self.wo[j][k] + lr * change
 
        # 更新输入层权重
        for i in range(self.ni):
            for j in range(self.nh):
                change = hidden_deltas[j] * self.ai[i]
                self.wi[i][j] = self.wi[i][j] + lr * change
 
        # 计算误差
        error = 0.0

        ''' 取期望输出和实际输出之差的平方和为误差函数'''
        error += 0.5 * (targets[k] - self.ao[k]) ** 2  #平方误差函数
        return error
 
    def test(self, patterns):
        count = 0
        for p in patterns:
            #原始类别
            target = flowerLables[(p[1].index(1))]
            result = self.update(p[0])

            #最大值的索引即为预测的类别flowerLables[index]
            index = result.index(max(result))
            print(p[0], ':', target, '->', flowerLables[index])
            
            #预测类别和原始类别相同时加1
            count += (target == flowerLables[index])
        
        #计算测试准确率
        accuracy = float(count / len(patterns))
        print('accuracy: %-.9f' % accuracy)
 
    def weights(self):
        print('输入层权重:')
        for i in range(self.ni):
            print(self.wi[i])
        print()
        print('输出层权重:')
        for j in range(self.nh):
            print(self.wo[j])
 
    def train(self, patterns, iterations=1000, lr=0.1):
        # lr: 学习速率(learning rate)
        for i in range(iterations):
            error = 0.0
            for p in patterns:
                inputs = p[0]
                targets = p[1]
                self.update(inputs)
                error = error + self.backPropagate(targets, lr)
            #每隔100次输出一次误差
            if i % 100 == 0:
                print('error: %-.9f' % error)
 
 
 
def iris():
    data = []
    # 读取数据
    raw = pd.read_csv('iris.csv')
    raw_data = raw.values
    raw_feature = raw_data[0:, 0:4]

    #将最后一列的鸢尾花类别转成one-hot编码形式
    for i in range(len(raw_feature)):
        ele = []
        ele.append(list(raw_feature[i]))
        if raw_data[i][4] == 'Iris-setosa':
            ele.append([1, 0, 0])
        elif raw_data[i][4] == 'Iris-versicolor':
            ele.append([0, 1, 0])
        else:
            ele.append([0, 0, 1])
        data.append(ele)
    # 随机打乱数据
    random.shuffle(data)
    # 选取打乱后的前100个作为训练数据
    training = data[0:100]
    # 选取打乱后的后50个作为测试数据
    test = data[101:]
    #输入层4个节点,隐藏层7个,输出层3个(100,010,001三类)
    nn = NN(4, 7, 3)
    # 训练网络,轮10000次
    nn.train(training, iterations=10000)
    #测试数据
    nn.test(test)


''' 
if __name__ == '__main__':  的作用
   当这个.py文件被直接运行时,if __name__ == '__main__'之下的代码块将被运行
   当这个.py文件以模块形式被导入时,if __name__ == '__main__'之下的代码块不被运行
'''
if __name__ == '__main__':
    iris()

训练结果:

error: 2.769657180
error: 0.007947966
error: 0.003326474
error: 0.002045720
error: 0.001470802
error: 0.001146268
error: 0.000938282
error: 0.000793788
error: 0.000687639
error: 0.000606396
error: 0.000542234
error: 0.000490292
error: 0.000447389
error: 0.000411360
error: 0.000380678
error: 0.000354238
error: 0.000331219
error: 0.000310998
error: 0.000293096
error: 0.000277135
error: 0.000262817
error: 0.000249901
error: 0.000238191
error: 0.000227525
error: 0.000217771
error: 0.000208817
error: 0.000200568
error: 0.000192943
error: 0.000185876
error: 0.000179306
error: 0.000173184
error: 0.000167465
error: 0.000162111
error: 0.000157087
error: 0.000152365
error: 0.000147917
error: 0.000143722
error: 0.000139757
error: 0.000136004
error: 0.000132447
error: 0.000129071
error: 0.000125863
error: 0.000122810
error: 0.000119901
error: 0.000117126
error: 0.000114476
error: 0.000111944
error: 0.000109521
error: 0.000107200
error: 0.000104976
error: 0.000102841
error: 0.000100792
error: 0.000098822
error: 0.000096928
error: 0.000095105
error: 0.000093349
error: 0.000091657
error: 0.000090024
error: 0.000088449
error: 0.000086928
error: 0.000085458
error: 0.000084037
error: 0.000082662
error: 0.000081332
error: 0.000080044
error: 0.000078795
error: 0.000077585
error: 0.000076412
error: 0.000075273
error: 0.000074168
error: 0.000073095
error: 0.000072052
error: 0.000071038
error: 0.000070053
error: 0.000069094
error: 0.000068162
error: 0.000067254
error: 0.000066370
error: 0.000065509
error: 0.000064670
error: 0.000063852
error: 0.000063054
error: 0.000062276
error: 0.000061517
error: 0.000060777
error: 0.000060053
error: 0.000059347
error: 0.000058658
error: 0.000057984
error: 0.000057325
error: 0.000056681
error: 0.000056052
error: 0.000055436
error: 0.000054833
error: 0.000054244
error: 0.000053667
error: 0.000053102
error: 0.000052549
error: 0.000052007
error: 0.000051477
[81, 5.5, 2.4, 3.8] : Iris-virginica -> Iris-virginica
[86, 6.0, 3.4, 4.5] : Iris-virginica -> Iris-virginica
[101, 6.3, 3.3, 6.0] : Iris-virginica -> Iris-virginica
[91, 5.5, 2.6, 4.4] : Iris-virginica -> Iris-virginica
[52, 6.4, 3.2, 4.5] : Iris-virginica -> Iris-virginica
[93, 5.8, 2.6, 4.0] : Iris-virginica -> Iris-virginica
[12, 4.8, 3.4, 1.6] : Iris-virginica -> Iris-virginica
[2, 4.9, 3.0, 1.4] : Iris-virginica -> Iris-virginica
[118, 7.7, 3.8, 6.7] : Iris-virginica -> Iris-virginica
[144, 6.8, 3.2, 5.9] : Iris-virginica -> Iris-virginica
[8, 5.0, 3.4, 1.5] : Iris-virginica -> Iris-virginica
[34, 5.5, 4.2, 1.4] : Iris-virginica -> Iris-virginica
[147, 6.3, 2.5, 5.0] : Iris-virginica -> Iris-virginica
[111, 6.5, 3.2, 5.1] : Iris-virginica -> Iris-virginica
[57, 6.3, 3.3, 4.7] : Iris-virginica -> Iris-virginica
[62, 5.9, 3.0, 4.2] : Iris-virginica -> Iris-virginica
[71, 5.9, 3.2, 4.8] : Iris-virginica -> Iris-virginica
[27, 5.0, 3.4, 1.6] : Iris-virginica -> Iris-virginica
[117, 6.5, 3.0, 5.5] : Iris-virginica -> Iris-virginica
[82, 5.5, 2.4, 3.7] : Iris-virginica -> Iris-virginica
[79, 6.0, 2.9, 4.5] : Iris-virginica -> Iris-virginica
[41, 5.0, 3.5, 1.3] : Iris-virginica -> Iris-virginica
[140, 6.9, 3.1, 5.4] : Iris-virginica -> Iris-virginica
[46, 4.8, 3.0, 1.4] : Iris-virginica -> Iris-virginica
[13, 4.8, 3.0, 1.4] : Iris-virginica -> Iris-virginica
[72, 6.1, 2.8, 4.0] : Iris-virginica -> Iris-virginica
[121, 6.9, 3.2, 5.7] : Iris-virginica -> Iris-virginica
[85, 5.4, 3.0, 4.5] : Iris-virginica -> Iris-virginica
[19, 5.7, 3.8, 1.7] : Iris-virginica -> Iris-virginica
[26, 5.0, 3.0, 1.6] : Iris-virginica -> Iris-virginica
[80, 5.7, 2.6, 3.5] : Iris-virginica -> Iris-virginica
[38, 4.9, 3.6, 1.4] : Iris-virginica -> Iris-virginica
[65, 5.6, 2.9, 3.6] : Iris-virginica -> Iris-virginica
[25, 4.8, 3.4, 1.9] : Iris-virginica -> Iris-virginica
[138, 6.4, 3.1, 5.5] : Iris-virginica -> Iris-virginica
[73, 6.3, 2.5, 4.9] : Iris-virginica -> Iris-virginica
[36, 5.0, 3.2, 1.2] : Iris-virginica -> Iris-virginica
[130, 7.2, 3.0, 5.8] : Iris-virginica -> Iris-virginica
[56, 5.7, 2.8, 4.5] : Iris-virginica -> Iris-virginica
[92, 6.1, 3.0, 4.6] : Iris-virginica -> Iris-virginica
[123, 7.7, 2.8, 6.7] : Iris-virginica -> Iris-virginica
[78, 6.7, 3.0, 5.0] : Iris-virginica -> Iris-virginica
[104, 6.3, 2.9, 5.6] : Iris-virginica -> Iris-virginica
[125, 6.7, 3.3, 5.7] : Iris-virginica -> Iris-virginica
[131, 7.4, 2.8, 6.1] : Iris-virginica -> Iris-virginica
[67, 5.6, 3.0, 4.5] : Iris-virginica -> Iris-virginica
[11, 5.4, 3.7, 1.5] : Iris-virginica -> Iris-virginica
[108, 7.3, 2.9, 6.3] : Iris-virginica -> Iris-virginica
[99, 5.1, 2.5, 3.0] : Iris-virginica -> Iris-virginica
accuracy: 1.000000000

 

转载:https://blog.csdn.net/qq_42570457/article/details/81454512

你可能感兴趣的:(机器学习)