For each channel, a mean and a variance are computed across the batch, giving one vector of per-channel means and one of per-channel variances. Each pixel in a channel is then normalized by subtracting that channel's mean and dividing by the channel's standard deviation (the square root of the variance). That is the BN operation; its output is then fed into the activation function. For example:
import torch
from torch import nn

# 1-D BN
d1 = torch.rand([2, 3, 4])  # BCW
bn1 = nn.BatchNorm1d(3, momentum=1)
res = bn1(d1)
print(res.shape)

# 2-D BN (the common case)
d2 = torch.rand([2, 3, 4, 5])  # BCHW
# Note: PyTorch's momentum weights the *new* batch statistics
# (running = (1 - momentum) * running + momentum * batch), so its default is 0.1.
# momentum=1 is used here only so the printed running stats equal this batch's stats.
bn2 = nn.BatchNorm2d(3, momentum=1)
res = bn2(d2)
print(res.shape)
print(bn2.running_mean)  # per-channel means (3 channels)
print(bn2.running_var)   # per-channel variances (3 channels)
Output:
torch.Size([2, 3, 4])
torch.Size([2, 3, 4, 5])
tensor([0.5622, 0.5005, 0.4583])
tensor([0.0914, 0.0774, 0.0840])
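Because momentum=1 above, the running statistics are exactly the statistics of the batch just seen. A quick sanity check, as a sketch that assumes the snippet above has just run (note that PyTorch stores the unbiased variance in running_var):

# With momentum=1, the running stats equal the last batch's stats.
manual_mean = d2.mean(dim=(0, 2, 3))
manual_var = d2.var(dim=(0, 2, 3), unbiased=True)  # running_var holds the unbiased estimate
print(torch.allclose(bn2.running_mean, manual_mean))  # True
print(torch.allclose(bn2.running_var, manual_var))    # True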
The momentum parameter: momentum controls how the mean and variance are accumulated. BN keeps a history of batch statistics, the moving mean and moving variance, and, borrowing the idea of the momentum method from optimization, lets the statistics of past batches carry over into the current batch. In the textbook convention the update is running = momentum * running + (1 - momentum) * batch, with momentum typically 0.9 or 0.99; after many batches, i.e. after many factors of 0.9 have been multiplied together, the influence of the earliest batches fades. (Beware that PyTorch defines momentum the opposite way, as the weight on the new batch statistics, which is why its default there is 0.1.) A small illustration of the decay follows.
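Under the textbook convention, the weight a batch seen k steps ago still carries in the moving average is momentum ** k. A tiny sketch of that decay:

momentum = 0.9
for k in (1, 10, 50):
    print(k, momentum ** k)  # 0.9, ~0.349, ~0.005: early batches fade out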
import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    """
    Forward pass for batch normalization.
    Inputs:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift parameter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var: Array of shape (D,) giving running variance of features
    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)
    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))
    out, cache = None, None
    if mode == 'train':
        sample_mean = np.mean(x, axis=0)  # per-feature mean, e.g. np.mean([[1,2],[3,4]], axis=0) -> [2, 3]
        sample_var = np.var(x, axis=0)
        out_ = (x - sample_mean) / np.sqrt(sample_var + eps)
        # exponential moving average of the batch statistics
        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var
        out = gamma * out_ + beta
        cache = (out_, x, sample_var, sample_mean, eps, gamma, beta)
    elif mode == 'test':
        # equivalent form: scale = gamma / np.sqrt(running_var + eps)
        #                  out = x * scale + (beta - running_mean * scale)
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_hat + beta
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)
    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var
    return out, cache
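A short sketch exercising the forward pass above (the random data and shapes here are illustrative, not from the original): after normalization with an identity affine transform, each feature should have roughly zero mean and unit variance.

np.random.seed(0)
x = np.random.randn(8, 4) * 3.0 + 5.0  # 8 samples, 4 features, shifted and scaled
gamma, beta = np.ones(4), np.zeros(4)  # identity affine transform
out, _ = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
print(out.mean(axis=0))  # approximately 0 per feature
print(out.std(axis=0))   # approximately 1 per feature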
Backward-pass code:
def batchnorm_backward(dout, cache):
    """
    Backward pass for batch normalization.
    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from batchnorm_forward.
    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = None, None, None
    out_, x, sample_var, sample_mean, eps, gamma, beta = cache
    N = x.shape[0]
    dout_ = gamma * dout  # gradient w.r.t. the normalized activations out_
    dvar = np.sum(dout_ * (x - sample_mean) * -0.5 * (sample_var + eps) ** -1.5, axis=0)
    dx_ = 1 / np.sqrt(sample_var + eps)
    dvar_ = 2 * (x - sample_mean) / N
    # intermediate for convenient calculation
    di = dout_ * dx_ + dvar * dvar_
    dmean = -1 * np.sum(di, axis=0)
    dmean_ = np.ones_like(x) / N
    dx = di + dmean * dmean_
    dgamma = np.sum(dout * out_, axis=0)
    dbeta = np.sum(dout, axis=0)
    return dx, dgamma, dbeta
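To gain confidence in the analytic gradients, a central-difference check is a sketch worth running (the loss out.sum(), the step h, and the random shapes below are all illustrative assumptions):

np.random.seed(1)
x = np.random.randn(5, 3)
gamma, beta = np.random.randn(3), np.random.randn(3)
out, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
dout = np.ones_like(out)  # corresponds to loss = out.sum()
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

# numeric gradient of the loss w.r.t. gamma via central differences
h = 1e-5
num_dgamma = np.zeros_like(gamma)
for i in range(gamma.size):
    gp, gm = gamma.copy(), gamma.copy()
    gp[i] += h
    gm[i] -= h
    op, _ = batchnorm_forward(x, gp, beta, {'mode': 'train'})
    om, _ = batchnorm_forward(x, gm, beta, {'mode': 'train'})
    num_dgamma[i] = (op.sum() - om.sum()) / (2 * h)
print(np.max(np.abs(num_dgamma - dgamma)))  # should be on the order of 1e-8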