误差反向传播算法

由于我们通过数值计算求得的权值梯度速度较慢,特别是当神经网络特别复杂时,我们求权值梯度的速度将会很慢很慢,为了提高效率,我们引入误差反向传播算法来提高求梯度的效率。

接下来以买苹果和橘子为例来对反向传播算法进行更深入的了解,

误差反向传播算法_第1张图片

由于求解总价格需要运用乘法和加法,所以我们需要定义一个乘法类和加法类,每个类中都需要定义正向传播forward和反向传播backward,每一个运算符就相当于一层,所以我们接下来一共需要定义四层,其中三层乘法类,一层加法类。

具体代码如下:

'''
之前我们求神经网络的梯度是通过数值微分来计算的,接下来我们将实现一个更高效的计算权值参数的梯度方法
'''
#误差反向传播法
#简单层的实现


'''
apple=100
apple_num=2
tax=1.1
#layer
mul_apple_layer=Mullayer()
mul_tax_layer=Mullayer()

#forward
apple_price=mul_apple_layer.forward(apple,apple_num)
price=mul_tax_layer.forward(apple_price,tax)
print(price)#220

#此外,关于各个变量的导数可由backward()求出
#backward
dprice=1
dapple_price,dtax=mul_tax_layer.backward(dprice)
dapple,dapple_num=mul_apple_layer.backward(dapple_price)
print(dapple,dapple_num,dtax)#2.2,110,200
'''
#首先实现乘法层(用类实现)
class Mullayer:
    def __init__(self):
        self.x=None
        self.y=None

    def forward(self,x,y):#正向传播
        self.x=x
        self.y=y
        out=x*y

        return out
    def backward(self,dout):
        dx=dout*self.y
        dy=dout*self.x
        return dx,dy
#加法层的实现
class AddLayer:
    def __init__(self):
        pass
    def forward(self,x,y):
        out=x+y
        return out
    def backward(self,dout):
        dx=dout*1
        dy=dout*1
        return dx,dy

#现在使用加法层和乘法层实现购买2个苹果和3个橘子的例子
apple=100
apple_num=2
orange=150
orange_num=3
tax=1.1

#layer
mul_apple_layer=Mullayer()
mul_orange_layer=Mullayer()
add_apple_orange_layer=AddLayer()
mul_tax_layer=Mullayer()

#forward
apple_price=mul_apple_layer.forward(apple,apple_num)
orange_price=mul_orange_layer.forward(orange,orange_num)
all_price=add_apple_orange_layer.forward(apple_price,orange_price)
price=mul_tax_layer.forward(all_price,tax)

#backward
dprice=1
dall_price,dtax=mul_tax_layer.backward(dprice)
dapple_price,dorange_price=add_apple_orange_layer.backward(dall_price)
dapple,dapple_num=mul_apple_layer.backward(dapple_price)
dorange,dorange_num=mul_orange_layer.backward(dorange_price)

print(price)#715
print(dapple,dapple_num,dorange,dorange_num,dtax)#2.2,110,3.3,165,650

你可能感兴趣的:(算法,python)