python3实现线性单元

理论知识见https://www.zybuluo.com/hanbingtao/note/448086

直接上python3的代码

#coding utf-8
import matplotlib.pyplot as plt
from functools import reduce

class perceptron(object):
    #初始化,输入训练数目,激活函数
    def __init__(self,input_num,activator):#activator为激活函数
        self.activator=activator
        self.weights=[0.0 for _ in range(input_num)]#权重初始化为0
        self.bias=0.0#偏置初始化为0.0
    #运算
    def operation(self,input_vec):
        #对激活函数中的参数做运算,x[0]代表input_vec,x[1]代表weights
        return self.activator(reduce(lambda a,b:a+b,map(lambda x:x[0]*x[1],zip(input_vec,self.weights)),0.0)+self.bias)#0.0为reduce的初始计算值
    #权值更新
    def updata(self,input_vec,output,label,rate):
        delta=label-output
        self.weights=list(map(lambda x:x[1]+rate*delta*x[0],zip(input_vec,self.weights)))#加上list跟python2有区别
        self.bias+=rate*delta
    #训练,输入数据及对应标签,迭代次数,学习率
    def train(self,input_vecs,labels,iteration_num,rate):
        for i in range(iteration_num):#iteration_num次迭代
            samples=zip(input_vecs,labels)#打包
            for (input_vec,label) in samples:
                output=self.operation(input_vec)#计算输出值
                self.updata(input_vec,output,label,rate)#更新
    #预测
    def predict(self,input_vec):
        return self.operation(input_vec)
    #打印权值,偏置
    def __str__(self):#内部函数
        return "weight: %s, bias: %f"%(self.weights,self.bias)#权值返回用%s
        
'''实现线性单元'''
#激活函数为线性函数
andActivator=lambda x:x

#得到训练数据
def getTrainData():
    input_vecs=[[5],[3],[8],[1.4],[10.1]]#可重用多次循环迭代
    labels=[5500,2300,7600,1800,11400]
    return input_vecs,labels
#训练感知机
def trainPerceptron():
    p=perceptron(1,andActivator)
    input_vecs,labels=getTrainData()
    p.train(input_vecs,labels,30,0.1)#100为迭代次数,0.1为学习率
    return p
#画图
def plot(linearUnit):
    input_vecs,labels=getTrainData()
    fig=plt.figure()
    ax=fig.add_subplot(111)
    ax.scatter(input_vecs,labels)#横坐标input_vecs,纵坐标labels
    weights=linearUnit.weights
    bias=linearUnit.bias
    x=range(0,15,1)#画0到15年的图像
    y=list(map(lambda x:weights[0]*x+bias,x))
    ax.plot(x,y)
    plt.show()
    
#主函数
if __name__=='__main__':
    train_perceptron=trainPerceptron()
    print(train_perceptron)
#测试
    print('工作3.4年,月薪%f'%train_perceptron.predict([3.4]))
    print('工作15年,月薪%f'%train_perceptron.predict([15]))
    print('工作1.5年,月薪%f'%train_perceptron.predict([1.5]))
    print('工作6.3年,月薪%f'%train_perceptron.predict([6.3]))
    plot(train_perceptron)
    

拟合得不好,见谅

python3实现线性单元_第1张图片

python3实现线性单元_第2张图片

你可能感兴趣的:(深度学习)