模型的指数移动平均EMA

1、概念

指数移动平均(Exponential Moving Average,EMA),也叫权重移动平均(Weighted Moving Average),是一种给予近期数据更高权重的平均方法。

就是说,用原理数据来影响现在数据的更新。

通俗版本理解:EMA是将每次梯度更新后的权值和前一次的权重进行联系,使得本次更新收到上次权值的影响。

2、原理

公式:v_{t}=\partial \cdot v_{t-1}+(1-\partial)\cdot \theta_{t}

a代表衰减率,该衰减率用于控制模型更新的速度,一般设为0.9-0.999。该值越大表示与上一次的影响越大,本次权重变化越小,与上次权重越接近,越稳定。

\theta_{t}表示本次通过计算的权重,也可以认为是“虚拟权重”

v_{t-1}表示上次的模型权重

3、应用

应用场景:在验证或者推理时使用。  滑动平均可以使模型在测试数据上更健壮(robust)。

现状:据说tensorflow中有对应的API,pytorch中没有,不过可以用类实现。网上有实现的版本。

TensorFlow 提供了 tf.train.ExponentialMovingAverage来实现滑动平均。

pytorch版本的实现代码,从博客中找的,自己并没有验证:

class EMA():
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def get(self, name):
        return self.shadow[name]

    def update(self, name, x):
        assert name in self.shadow
        new_average = (1.0 - self.decay) * x + self.decay * self.shadow[name]
        self.shadow[name] = new_average.clone()

你可能感兴趣的:(pytorch,深度学习,pytorch,神经网络)