batch_norm

    batch_norm normalizes each feature of the input, then scales and shifts the result (this is the role of gamma and beta).

    This note mainly records the implementation.

   Forward pass:

    It is fairly straightforward: compute the mean and variance of x, then apply the following formulas:

\hat{x} = \frac{x-\mu }{\sqrt{\sigma^2 + \varepsilon }} 

y = \gamma\hat{x} + \beta

    The moving averages of the mean and variance are also tracked here, for use at test time.
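
    Concretely, during training the running statistics are updated with an exponential moving average (m below denotes the momentum), and at test time they replace the batch statistics in the same normalization formula:

\mu_{run} \leftarrow m\,\mu_{run} + (1-m)\,\mu_{batch}, \qquad \sigma^2_{run} \leftarrow m\,\sigma^2_{run} + (1-m)\,\sigma^2_{batch}

y_{test} = \gamma\,\frac{x-\mu_{run}}{\sqrt{\sigma^2_{run} + \varepsilon}} + \beta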

import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    if mode == 'train':
        sample_mean = np.mean(x, axis = 0)
        sample_var = np.var(x, axis = 0)
        x_hat = (x - sample_mean) / np.sqrt(sample_var + eps)
        out = gamma * x_hat + beta
        cache = (x, gamma, beta, x_hat, sample_mean, sample_var, eps)

        running_mean = momentum * running_mean + (1 - momentum) * sample_mean
        running_var = momentum * running_var + (1 - momentum) * sample_var
    elif mode == 'test':
        # Normalize with the running statistics accumulated during training
        normed_x = (x - running_mean) / np.sqrt(running_var + eps)
        out = normed_x * gamma + beta
        cache = None
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache
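
    A quick sanity check (a minimal sketch; the shapes, seed, and constants below are arbitrary choices, not part of the original code): in train mode each output feature should end up with mean close to beta and standard deviation close to gamma.

np.random.seed(0)
x = 5 * np.random.randn(200, 3) + 10
gamma = np.array([1.0, 2.0, 3.0])
beta = np.array([0.0, 1.0, -1.0])

out, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
print(out.mean(axis=0))  # close to beta
print(out.std(axis=0))   # close to gamma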

   Backward pass:

    The backward pass mainly computes the partial derivatives; there are two ways to do this.

    Method 1:

    Follow the forward steps in reverse, applying the chain rule one step at a time; wherever a value feeds into more than one branch, compute the gradient from each branch and sum them. The computational graph of the forward pass is shown below:

[Figure 1: computational graph of the batch-norm forward pass]
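
    The step numbers in the backward code below refer to the following decomposition of the forward pass into elementary operations (a sketch for reference; the function name batchnorm_forward_steps is only illustrative, and the variable names match those used in the backward code):

def batchnorm_forward_steps(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)            # step1: per-feature mean
    xmu = x - mu                   # step2: center the data
    sq = xmu ** 2                  # step3a: squared deviations
    var = sq.mean(axis=0)          # step3b: per-feature variance
    sqrtvar = np.sqrt(var + eps)   # step4a: standard deviation
    ivar = 1. / sqrtvar            # step4b: inverse standard deviation
    xhat = xmu * ivar              # step5: normalized data
    out = gamma * xhat + beta      # step6: scale and shift
    return out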

def batchnorm_backward(dout, cache):
    # Unpack intermediate variables saved in the forward pass
    x, gamma, beta, xhat, mean, var, eps = cache

    xmu = x - mean
    sqrtvar = np.sqrt(var + eps)
    ivar = 1./sqrtvar

    N,D = dout.shape

    #step6
    dbeta = np.sum(dout, axis=0)
    dgammax = dout
    dgamma = np.sum(dgammax*xhat, axis=0)
    dxhat = dgammax * gamma

    #step5
    divar = np.sum(dxhat*xmu, axis=0)
    dxmu1 = dxhat * ivar # note: this is one branch of the gradient w.r.t. xmu

    #step4
    dsqrtvar = -1. /(sqrtvar**2) * divar
    dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar

    #step3
    dsq = 1. /N * np.ones((N,D)) * dvar
    dxmu2 = 2 * xmu * dsq # note: this is the second branch of the gradient w.r.t. xmu

    #step2
    dx1 = (dxmu1 + dxmu2) # note: this is one branch of the gradient w.r.t. x


    #step1
    dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
    dx2 = 1. /N * np.ones((N,D)) * dmu # note: this is the second branch of the gradient w.r.t. x

    #step0 done!
    dx = dx1 + dx2
   
    return dx, dgamma, dbeta
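
    One way to verify the backward pass is to compare dx against a numerical gradient. The sketch below uses a simple central-difference check; the helper num_grad, the shapes, and the seed are illustrative choices, not part of the original code.

def num_grad(f, x, dout, h=1e-5):
    # Central-difference estimate of d(sum(f(x) * dout)) / dx
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + h
        pos = f(x)
        x[idx] = old - h
        neg = f(x)
        x[idx] = old
        grad[idx] = np.sum((pos - neg) * dout) / (2 * h)
        it.iternext()
    return grad

np.random.seed(1)
x = np.random.randn(4, 5)
gamma = np.random.randn(5)
beta = np.random.randn(5)
dout = np.random.randn(4, 5)

_, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})
dx, dgamma, dbeta = batchnorm_backward(dout, cache)

fx = lambda x: batchnorm_forward(x, gamma, beta, {'mode': 'train'})[0]
dx_num = num_grad(fx, x, dout)
print(np.max(np.abs(dx - dx_num)))  # should be tiny, around 1e-8 or smaller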

      Method 2: merge and simplify the steps above, or derive the partial derivatives directly from the original formulas.

def batchnorm_backward_alt(dout, cache):
    x, gamma, beta, x_hat, sample_mean, sample_var, eps = cache
    m = dout.shape[0] # m is N here
    dxhat = dout * gamma # (N, D)
    dvar = (dxhat * (x-sample_mean) * (-0.5) * np.power(sample_var+eps, -1.5)).sum(axis = 0)  # (D,)
    dmean = np.sum(dxhat * (-1) * np.power(sample_var + eps, -0.5), axis = 0)
    dmean += dvar * np.sum(-2 * (x - sample_mean), axis = 0) / m
    dx = dxhat * np.power(sample_var + eps, -0.5) + dvar*2*(x - sample_mean) / m + dmean / m
    dgamma = np.sum(dout * x_hat, axis = 0)
    dbeta = np.sum(dout, axis = 0)

    return dx, dgamma, dbeta

    Method 2 needs fewer computation steps than method 1 and in practice runs roughly twice as fast.
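
    A rough way to check the speed difference (a sketch; the shapes and repetition count are arbitrary choices):

import timeit

np.random.seed(2)
x = np.random.randn(256, 100)
gamma = np.random.randn(100)
beta = np.random.randn(100)
dout = np.random.randn(256, 100)
_, cache = batchnorm_forward(x, gamma, beta, {'mode': 'train'})

# Time both backward implementations on the same cache and upstream gradient
t1 = timeit.timeit(lambda: batchnorm_backward(dout, cache), number=1000)
t2 = timeit.timeit(lambda: batchnorm_backward_alt(dout, cache), number=1000)
print('method 1: %.4fs  method 2: %.4fs  speedup: %.2fx' % (t1, t2, t1 / t2))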
