Batch normalization normalizes each input feature over the batch, then scales and shifts the result (this is what gamma and beta do). These notes record its implementation.

The forward pass is simple: compute the mean and variance of x over the batch, then normalize and affine-transform with the following formulas:

    x_hat = (x - mean) / sqrt(var + eps)
    out   = gamma * x_hat + beta

The forward pass also maintains moving averages of the mean and variance, which are used at test time.
import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    cache = None
    if mode == 'train':
        sample_mean = np.mean(x, axis=0)
        sample_var = np.var(x, axis=0)
        x_hat = (x - sample_mean) / np.sqrt(sample_var + eps)
        out = gamma * x_hat + beta
        cache = (x, gamma, beta, x_hat, sample_mean, sample_var, eps)
        # Update the moving averages for use at test time
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var
    elif mode == 'test':
        # Normalize with the running statistics; note the sqrt on the variance
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_hat + beta
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
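As a quick sanity check of the train-mode path, the sketch below normalizes a random batch by hand (the same formula as above, with gamma = 1 and beta = 0) and confirms that each feature ends up with roughly zero mean and unit variance:

```python
import numpy as np

np.random.seed(0)
x = 5.0 + 3.0 * np.random.randn(8, 4)  # batch of 8 samples, 4 features
eps = 1e-5

# Same normalization as the train branch, with gamma = 1 and beta = 0
x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

print(np.allclose(x_hat.mean(axis=0), 0.0, atol=1e-8))  # per-feature mean ~ 0
print(np.allclose(x_hat.std(axis=0), 1.0, atol=1e-3))   # per-feature std ~ 1
```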
The backward pass boils down to computing the partial derivatives. There are two approaches.

Approach 1:

Follow the forward steps in reverse, applying the chain rule one step at a time; wherever a variable feeds into multiple branches, compute each branch's gradient separately and sum them. The computation graph of the forward pass is shown in the figure below:
def batchnorm_backward(dout, cache):
    # Unpack intermediate variables
    x, gamma, beta, xhat, mean, var, eps = cache
    xmu = x - mean
    sqrtvar = np.sqrt(var + eps)
    ivar = 1. / sqrtvar
    N, D = dout.shape

    # step6: out = gamma * xhat + beta
    dbeta = np.sum(dout, axis=0)
    dgammax = dout
    dgamma = np.sum(dgammax * xhat, axis=0)
    dxhat = dgammax * gamma

    # step5: xhat = xmu * ivar
    divar = np.sum(dxhat * xmu, axis=0)
    dxmu1 = dxhat * ivar  # note: this is one branch of xmu

    # step4: ivar = 1 / sqrtvar
    dsqrtvar = -1. / (sqrtvar**2) * divar
    dvar = 0.5 * 1. / np.sqrt(var + eps) * dsqrtvar

    # step3: var = mean of squared deviations
    dsq = 1. / N * np.ones((N, D)) * dvar
    dxmu2 = 2 * xmu * dsq  # note: this is the second branch of xmu

    # step2: xmu = x - mean
    dx1 = (dxmu1 + dxmu2)  # note: this is one branch of x

    # step1: mean = sum(x) / N
    dmu = -1 * np.sum(dxmu1 + dxmu2, axis=0)
    dx2 = 1. / N * np.ones((N, D)) * dmu  # note: this is the second branch of x

    # step0: done!
    dx = dx1 + dx2

    return dx, dgamma, dbeta
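A standard way to validate the staged backward pass is to compare dx against a numerical gradient computed with central differences. The sketch below re-implements the forward pass and the dx portion of the staged backward compactly; `bn_forward` and `bn_backward_dx` are local helper names for this check only, not part of the code above:

```python
import numpy as np

def bn_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def bn_backward_dx(dout, x, gamma, eps=1e-5):
    # Condensed version of the staged (method 1) backward pass, dx only
    N = x.shape[0]
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    xmu = x - mu
    ivar = 1.0 / np.sqrt(var + eps)
    dxhat = dout * gamma
    divar = np.sum(dxhat * xmu, axis=0)
    dxmu1 = dxhat * ivar
    dsqrtvar = -divar / (var + eps)
    dvar = 0.5 * dsqrtvar / np.sqrt(var + eps)
    dxmu2 = 2.0 * xmu * dvar / N
    dx1 = dxmu1 + dxmu2
    dmu = -np.sum(dx1, axis=0)
    return dx1 + dmu / N

# Numerical gradient check with central differences
np.random.seed(1)
x = np.random.randn(5, 3)
gamma = np.random.randn(3)
beta = np.random.randn(3)
dout = np.random.randn(5, 3)

dx = bn_backward_dx(dout, x, gamma)
num = np.zeros_like(x)
h = 1e-5
for i in range(x.size):
    idx = np.unravel_index(i, x.shape)
    xp = x.copy(); xp[idx] += h
    xm = x.copy(); xm[idx] -= h
    num[idx] = np.sum((bn_forward(xp, gamma, beta)
                       - bn_forward(xm, gamma, beta)) * dout) / (2 * h)

print(np.max(np.abs(dx - num)))  # should be tiny, around 1e-8 or smaller
```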
Approach 2: merge and simplify the steps above, or differentiate the original formulas directly.
def batchnorm_backward_alt(dout, cache):
    x, gamma, beta, x_hat, sample_mean, sample_var, eps = cache
    m = dout.shape[0]  # m is N here

    dxhat = dout * gamma  # (N, D)
    dvar = (dxhat * (x - sample_mean) * (-0.5) * np.power(sample_var + eps, -1.5)).sum(axis=0)  # (D,)
    dmean = np.sum(dxhat * (-1) * np.power(sample_var + eps, -0.5), axis=0)
    dmean += dvar * np.sum(-2 * (x - sample_mean), axis=0) / m
    dx = dxhat * np.power(sample_var + eps, -0.5) + dvar * 2 * (x - sample_mean) / m + dmean / m
    dgamma = np.sum(dout * x_hat, axis=0)
    dbeta = np.sum(dout, axis=0)

    return dx, dgamma, dbeta
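The simplification can be pushed one step further: substituting dmean and dvar back into dx collapses approach 2 to a single expression, dx = ivar/N * (N*dxhat - sum(dxhat) - x_hat * sum(dxhat * x_hat)). The sketch below (standalone, not calling the functions above) checks numerically that this fused form matches the approach-2 computation:

```python
import numpy as np

np.random.seed(2)
N, D = 6, 4
x = np.random.randn(N, D)
gamma = np.random.randn(D)
dout = np.random.randn(N, D)
eps = 1e-5

mu = x.mean(axis=0)
var = x.var(axis=0)
x_hat = (x - mu) / np.sqrt(var + eps)
dxhat = dout * gamma

# Approach 2, as in batchnorm_backward_alt
dvar = np.sum(dxhat * (x - mu) * (-0.5) * np.power(var + eps, -1.5), axis=0)
dmean = np.sum(-dxhat * np.power(var + eps, -0.5), axis=0)
dmean += dvar * np.sum(-2 * (x - mu), axis=0) / N
dx_alt = dxhat * np.power(var + eps, -0.5) + dvar * 2 * (x - mu) / N + dmean / N

# Fused single-expression form
ivar = 1.0 / np.sqrt(var + eps)
dx_fused = ivar / N * (N * dxhat - dxhat.sum(axis=0)
                       - x_hat * (dxhat * x_hat).sum(axis=0))

print(np.allclose(dx_alt, dx_fused))  # True
```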
Approach 2 has fewer computation steps than approach 1 and in practice runs roughly 2x faster.