Batch Normalization(BN)Python实现

Batch Normalization(BN)Python实现_第1张图片
Batch Normalization(BN)Python实现_第2张图片
Batch Normalization(BN)Python实现_第3张图片

def Batchnorm_simple_for_train(x, gamma, beta, bn_param):
"""
param:x    : 输入数据,设shape(B,L)
param:gama : 缩放因子  γ
param:beta : 平移因子  β
param:bn_param   : batchnorm所需要的一些参数
	eps      : 接近0的数,防止分母出现0
	momentum : 动量参数,一般为0.9, 0.99, 0.999
	running_mean :滑动平均的方式计算新的均值,训练时计算,为测试数据做准备
	running_var  : 滑动平均的方式计算新的方差,训练时计算,为测试数据做准备
"""
	running_mean = bn_param['running_mean']  #shape = [B]
    running_var = bn_param['running_var']    #shape = [B]
	results = 0. # 建立一个新的变量
    
	x_mean=x.mean(axis=0)  # 计算x的均值
    x_var=x.var(axis=0)    # 计算方差
    x_normalized=(x-x_mean)/np.sqrt(x_var+eps)       # 归一化
    results = gamma * x_normalized + beta            # 缩放平移

    running_mean = momentum * running_mean + (1 - momentum) * x_mean
    running_var = momentum * running_var + (1 - momentum) * x_var
    
    #记录新的值
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var 
    
	return results , bn_param

滑动平均法记录mean和var,用于测试。
Batch Normalization(BN)Python实现_第4张图片

def Batchnorm_simple_for_test(x, gamma, beta, bn_param):
"""
param:x    : 输入数据,设shape(B,L)
param:gama : 缩放因子  γ
param:beta : 平移因子  β
param:bn_param   : batchnorm所需要的一些参数
	eps      : 接近0的数,防止分母出现0
	momentum : 动量参数,一般为0.9, 0.99, 0.999
	running_mean :滑动平均的方式计算新的均值,训练时计算,为测试数据做准备
	running_var  : 滑动平均的方式计算新的方差,训练时计算,为测试数据做准备
"""
	running_mean = bn_param['running_mean']  #shape = [B]
    running_var = bn_param['running_var']    #shape = [B]
	results = 0. # 建立一个新的变量
   
    x_normalized=(x-running_mean )/np.sqrt(running_var +eps)       # 归一化
    results = gamma * x_normalized + beta            # 缩放平移
    
	return results , bn_param

【深度学习】批归一化(Batch Normalization):https://blog.csdn.net/vict_wang/article/details/88075861
Batch Normalization原理与python实现:https://zhuanlan.zhihu.com/p/100672008
基础 | batchnorm原理及代码详解:https://blog.csdn.net/qq_25737169/article/details/79048516

你可能感兴趣的:(深度学习,python,机器学习,深度学习,人工智能,numpy)