tensorflow batch nornalization的理解以及实现

用于最中执行batch normalization的函数
tf.nn.batch_normalization(
x,
mean,
variance,
offset,
scale,
variance_epsilon,
name=None
)

参数:
x是input输入样本
mean是样本均值
variance是样本方差
offset是样本偏移(相加一个转化值)
scale是缩放(默认为1)
variance_epsilon是为了避免分母为0,添加的一个极小值
输出的计算公式为:
y = scale * (x - mean) / var + offset


def moments(
x,
axes,
shift=None, # pylint: disable=unused-argument
name=None,
keep_dims=False):

参数:
x:一个tensor张量,即我们的输入数据
axes:一个int型数组,它用来指定我们计算均值和方差的轴(这里不好理解,可以结合下面的例子)
shift:当前实现中并没有用到
name:用作计算moment操作的名称
keep_dims:输出和输入是否保持相同的维度

返回:
两个tensor张量:均值和方差


def mean_var2tensor(input_variable):
    v_shape = input_variable.get_shape()
    axis = [len(v_shape) - 1]
    v_mean, v_var = tf.nn.moments(input_variable, axes=axis, keep_dims=True)
    return v_mean, v_var


def mean_var2numpy(input_variable):
    v_shape = input_variable.get_shape()
    axis = [len(v_shape) - 1]
    v_mean, v_var = tf.nn.moments(input_variable, axes=axis, keep_dims=True)
    return v_mean, v_var


def my_batch_normlization(input_x):
    # Batch Normalize
    x_shape = input_x.get_shape()
    axis = [len(x_shape) - 1]
    x_mean, x_var = tf.nn.moments(input_x, axes=axis, keep_dims=True)
    scale = tf.constant(0.1)   # 所有的batch 使用同一个scale因子
    shift = tf.constant(0.001)  # 所有的batch 使用同一个shift项
    epsilon = 0.001

    out_x = tf.nn.batch_normalization(input_x, x_mean, x_var, shift, scale, epsilon)
    return out_x

你可能感兴趣的:(python,tensorflow学习)