tensorflow 滑动平均模型 ExponentialMovingAverage

滑动平均模型对于采用GradientDescent或Momentum训练的神经网络的表现都有一定程度上的提升。

原理:在训练神经网络时,不断保持和更新每个参数的滑动平均值,在验证和测试时,参数的值使用其滑动平均值,能有效提高神经网络的准确率。

tf.train.ExponentialMovingAverage

tensorflow官网地址:https://www.tensorflow.org/versions/r0.12/api_docs/python/train/moving_averages

tensorflow中提供了tf.train.ExponentialMovingAverage来实现滑动平均模型,他使用指数衰减来计算变量的移动平均值。

tf.train.ExponentialMovingAverage.__init__(self, decay, num_updates=None, zero_debias=False, name="ExponentialMovingAverage"):

decay是衰减率

num_updates是ExponentialMovingAverage提供用来动态设置decay的参数,当初始化时提供了参数,即不为none时,每次的衰减率是:

min { decay , ( 1 + num_updates ) / ( 10 + num_updates ) }

apply()方法添加了训练变量的影子副本,并保持了其影子副本中训练变量的移动平均值操作。在每次训练之后调用此操作,更新移动平均值。

average()和average_name()方法可以获取影子变量及其名称。

在创建ExponentialMovingAverage对象时,需指定衰减率(decay),用于控制模型的更新速度。影子变量的初始值与训练变量的初始值相同。当运行变量更新时,每个影子变量都会更新为:

shadow_variable = decay * shadow_variable + (1 - decay) * variable

decay设置为接近1的值比较合理,通常为:0.999,0.9999等

滑动平均的原理理解

    # -*- coding: utf-8 -*-  
    """ 
    @author: tz_zs 
     
    滑动平均模型 
    """  
    import tensorflow as tf  
      
    # 定义一个变量,用于滑动平均计算、  
    v1 = tf.Variable(0, dtype=tf.float32)  
    # 定义一个变量step,表示迭代的轮数,用于动态控制衰减率  
    step = tf.Variable(0, trainable=False)  
      
    # 定义滑动平均的对象  
    ema = tf.train.ExponentialMovingAverage(0.99, step)  
      
    # 定义执行保持滑动平均的操作,  参数为一个列表格式  
    maintain_average_op = ema.apply([v1])  
      
    with tf.Session() as sess:  
        #  初始化所有变量  
        init_op = tf.global_variables_initializer()  
        sess.run(init_op)  
      
        # 通过ema.average(v1)获取滑动平均之后变量的取值,  
        # print(sess.run(v1))  # 0.0  
        # print(sess.run([ema.average_name(v1), ema.average(v1)]))  # [None, 0.0]  
        print(sess.run([v1, ema.average(v1)]))  # [0.0, 0.0]  
      
        # 更新变量v1的值为5  
        sess.run(tf.assign(v1, 5))  
        # 更新v1的滑动平均值,衰减率 min { decay , ( 1 + num_updates ) / ( 10 + num_updates ) }=0.1  
        # 所以v1的滑动平均会被更新为 0.1*0 + 0.9*5 = 4.5  
        sess.run(maintain_average_op)  
        # print(sess.run(v1))  # 5.0  
        # print(sess.run([ema.average_name(v1), ema.average(v1)]))  # [None, 4.5]  
        print(sess.run([v1, ema.average(v1)]))  # [5.0, 4.5]  
      
        # 更新step的值为10000。模拟迭代轮数  
        sess.run(tf.assign(step, 10000))  
        # 跟新v1的值为10  
        sess.run(tf.assign(v1, 10))  
        # 更新v1的滑动平均值。衰减率为 min { decay , ( 1 + num_updates ) / ( 10 + num_updates ) }得到 0.99  
        # 所以v1的滑动平均值会被更新为 0.99*4.5 + 0.01*10 = 4.555  
        sess.run(maintain_average_op)  
        print(sess.run([v1, ema.average(v1)]))  # [10.0, 4.5549998]  
      
        # 再次更新滑动平均值,将得到 0.99*4.555 + 0.01*10 =4.60945  
        sess.run(maintain_average_op)  
        print(sess.run([v1, ema.average(v1)]))  # [10.0, 4.6094499]  
    # -*- coding: utf-8 -*-  
    """ 
    @author: tz_zs 
     
    """  
    import tensorflow as tf  
      
    v1 = tf.Variable(10, dtype=tf.float32, name="v")  
      
    for variables in tf.global_variables():  # all_variables弃用了  
        print(variables)  #   
      
    ema = tf.train.ExponentialMovingAverage(0.99)  
    print(ema)  #   
      
    maintain_averages_op = ema.apply(tf.global_variables())  
    for variables in tf.global_variables():  
        print(variables)  #   
        #   
      
    with tf.Session() as sess:  
        tf.global_variables_initializer().run()  
        sess.run(tf.assign(v1, 1))  
        sess.run(maintain_averages_op)  
        print(sess.run([v1, ema.average(v1)]))  # [1.0, 9.9099998]  

    # -*- coding: utf-8 -*-  
    """ 
    @author: tz_zs 
     
    滑动平均值的存储和加载(持久化) 
    """  
    import tensorflow as tf  
      
    v1 = tf.Variable(10, dtype=tf.float32, name="v1")  
      
    for variables in tf.global_variables():  # all_variables弃用了  
        print(variables)  #   
      
    ema = tf.train.ExponentialMovingAverage(0.99)  
    print(ema)  #   
      
    maintain_averages_op = ema.apply(tf.global_variables())  
    for variables in tf.global_variables():  
        print(variables)  #   
        #   
      
    saver = tf.train.Saver()  
    print(saver)  #   
    with tf.Session() as sess:  
        tf.global_variables_initializer().run()  
        sess.run(tf.assign(v1, 1))  
        sess.run(maintain_averages_op)  
        print(sess.run([v1, ema.average(v1)]))  # [1.0, 9.9099998]  
      
        print(saver.save(sess, "/path/to/model.ckpt"))  # 持久化存储____会返回路径 /path/to/model.ckpt  
    #################################################################################################  
    print("#####" * 10)  
    print("加载")  
    #################################################################################################  
    var2 = tf.Variable(0, dtype=tf.float32, name="v2")  #   
    print(var2)  
    saver2 = tf.train.Saver({"v1/ExponentialMovingAverage": var2})  
    with tf.Session() as sess2:  
        saver2.restore(sess2, "/path/to/model.ckpt")  
        print(sess2.run(var2))  # 9.91 所以,成功加载了v1的滑动平均值  

也可以使用tensorflow提供的variable_to_restore函数完成加载

    var3 = tf.Variable(0, dtype=tf.float32, name="v1")  
    print(var3)  #   
    ema = tf.train.ExponentialMovingAverage(0.99)  
      
    print(ema.variables_to_restore())  # {'v1/ExponentialMovingAverage': }  
    saver = tf.train.Saver(ema.variables_to_restore())  
    with tf.Session() as sess:  
        saver.restore(sess, "/path/to/model.ckpt")  
        print(sess.run(var3))  # 9.91  

附录2:移动平均法相关知识(转)

来源地址:http://wiki.mbalib.com/wiki/%E7%A7%BB%E5%8A%A8%E5%B9%B3%E5%9D%87%E6%B3%95

 移动平均法又称滑动平均法、滑动平均模型法(Moving averageMA



什么是移动平均法?

   移动平均法是用一组最近的实际数据值来预测未来一期或几期内公司产品需求量、公司产能等的一种常用方法。移动平均法适用于即期预测。当产品需求既不快速增长也不快速下降,且不存在季节性因素时,移动平均法能有效地消除预测中的随机波动,是非常有用的。移动平均法根据预测时使用的各元素的权重不同
 
  移动平均法是一种简单平滑预测技术,它的基本思想是:根据时间序列资料、逐项推移,依次计算包含一定项数的序时平均值,以反映长期趋势的方法。因此,当时间序列的数值由于受周期变动和随机波动的影响,起伏较大,不易显示出事件的发展趋势时,使用移动平均法可以消除这些因素的影响,显示出事件的发展方向与趋势(即趋势线),然后依趋势线分析预测序列的长期趋势



移动平均法的种类

  移动平均法可以分为:简单移动平均和 加权移动平均
 

  一、简单移动平均法

 
   简单移动平均的各元素的权重都相等。简单的移动平均的计算公式如下:  Ft =( At-1 At-2 At-3 At-n /n 式中,
 
   · Ft-- 对下一期的预测值;
 
   · n-- 移动平均的时期个数;
 
   · At-1-- 前期实际值;
 
   · At-2 At-3 At-n 分别表示前两期、前三期直至前 n 期的实际值。
 

  二、加权移动平均法

 
  加权移动平均 给固定跨越期限内的每个变量值以不同的权重。其原理是:历史各期产品需求的数据信息对预测未来期内的需求量的作用是不一样的。除了以 n 为周期的周期性变化外,远离目标期的变量值影响力相对较低,故应给予较低的权重。 加权移动平均法 的计算公式如下:
 
   Ft w1At-1 w2At-2 w3At-3 wnAt-n 式中,
 
   · w1-- t-1 实际销售额的权重;
 
   · w2-- t-2 期实际销售额的权重;
 
   · wn-- t-n 期实际销售额的权
 
   · n-- 预测的时期数 w1  w2  wn 1
 
  在运用加权平均法时,权重的选择是一个应该注意的问题。经验法和试算法是选择权重的最简单的方法。一般而言,最近期的数据最能预示未来的情况,因而权重应大些。例如,根据前一个月的利润和生产能力比起根据前几个月能更好的估测下个月的利润和生产能力。但是,如果数据是季节性的,则权重也应是季节性的。


 

移动平均法的优缺点

 
   使用移动平均法进行预测能平滑掉需求的突然波动对预测结果的影响。但移动平均法运用时也存在着如下问题:
 
   1 加大移动平均法的期数(即加大 n 值)会使平滑波动效果更好,但会使预测值对数据实际变动更不敏感;
 
   2 移动平均值并不能总是很好地反映出趋势。由于是平均值,预测值总是停留在过去的水平上而无法预计会导致将来更高或更低的波动;
 
   3 移动平均法要由大量的过去数据的记录。



移动平均法案例分析

 

  案例一:移动平均法在公交运行时间预测中的应用

 
   公交车运行时间原始数据的采集采用的是人工测试法,即由记录人员从起始点到终点跟踪每辆客车,并记录下车辆在每个站点之间的运行时间。行驶路线选用的是长春公交 306 路,始发站为长春大学,终点站为火车站。数据采集的日期是从 2001 4 3 4 5 。这三天属工作日,因为公交运行时间因时间的不同而有不同的结果。所以这些数据只作为预测工作日运行时间。采集的数据是该路从工农广场站点到桂林路站点之间的运行时间。
 
  ( 1 N 3-20 ,利用移动平均法预测得到的结果见表 1
 
  移动平均法预测表
 
 
K
N
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
6 40
15
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6 41
16
5
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6 41
17
4
4
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6 42
18
4
4
4
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6 43
19
4
4
4
4
4
 
 
 
 
 
 
 
 
 
 
 
 
 
6 44
20
4
4
4
4
4
4
 
 
 
 
 
 
 
 
 
 
 
 
6 45
21
4
4
4
4
4
4
4
 
 
 
 
 
 
 
 
 
 
 
6 4622
4
4
4
4
4
4
4
4
 
 
 
 
 
 
 
 
 
 
6 47
22
4
4
4
4
4
4
4
4
4
 
 
 
 
 
 
 
 
 
6 48
23
4
4
4
4
4
4
4
4
4
4
 
 
 
 
 
 
 
 
6 49
24
5
4
4
4
4
4
4
4
4
4
4
 
 
 
 
 
 
 
6 50
25
5
5
5
4
4
4
4
4
4
4
4
4
 
 
 
 
 
 
6 51
26
5
5
5
5
5
4
4
4
4
4
4
4
 
 
 
 
 
 
6 52
27
5
5
5
5
5
5
5
4
4
4
4
4
4
 
 
 
 
 
6 53
28
5
5
5
5
5
5
5
5
5
5
5
5
5
5
 
 
 
 
6 54
29
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
 
 
 
6 55
30
6
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
 
 
6 56
31
6
6
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
 
6 57
32
6
6
6
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
6 58
33
6
6
6
6
5
5
5
5
5
5
5
5
5
5
5
5
5
5
6 59
34
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 00
35
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 01
36
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 02
37
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 03
38
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 04
39
4
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 05
40
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 06
41
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
7 07
42
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
7 08
43
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
7 09
44
4
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
7 10
45
4
4
4
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
7 11
46
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
5
5
5
7 12
47
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
5
7 13
48
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
7 14
49
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
7 15
50
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
7 16
51
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
7 17
52
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
7 18
53
5
5
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
7 19
54
5
5
5
5
5
4
4
4
4
4
4
4
4
4
4
4
4
4
7 20
55
5
5
5
5
5
5
5
4
4
4
4
4
4
4
4
4
4
4
7 21
56
5
5
5
5
5
5
5
5
5
4
4
4
4
4
4
4
4
4
7 22
57
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
4
4
4
7 23
58
5
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
4
4
7 24
59
5
5
5
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
7 25
60
5
5
5
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
7 26
61
4
4
5
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
7 27
62
4
4
4
4
5
5
5
5
5
5
5
5
5
5
4
4
4
4
7 28
63
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
4
4
4
7 29
64
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
4
4
7 30
65
5
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
4
7 31
66
5
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
7 32
67
5
5
5
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
7 33
68
5
5
5
5
4
4
4
4
4
4
5
5
5
5
5
5
5
5
7 34
69
5
5
5
5
5
4
4
4
4
4
4
5
5
5
5
5
5
5
7 35
70
5
5
5
5
5
5
4
4
4
4
4
4
5
5
5
5
5
5
7 36
71
5
5
5
5
5
5
5
4
4
4
4
4
4
5
5
5
5
5
7 37
72
5
5
5
5
5
5
5
5
4
4
4
4
4
4
5
5
5
5
7 38
73
5
5
5
5
5
5
5
5
5
4
4
4
4
4
4
5
5
5
7 39
74
5
5
5
5
5
5
5
5
5
5
4
4
4
4
4
4
5
5
7 40
75
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
4
4
5
7 41
76
5
5
5
5
5
5
5
5
5
5
5
5
5
4
4
4
4
5
7 42
77
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 43
78
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 44
79
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 45
80
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 46
81
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 47
82
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 48
83
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 49
84
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 50
85
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 51
86
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 52
87
4
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 53
88
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 54
89
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
5
7 55
90
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
7 56
91
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
7 57
92
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
7 58
93
4
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
7 59
94
4
4
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
8 00
95
4
4
4
4
4
4
4
4
4
4
4
4
4
4
5
5
5
5
8 01
96
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
5
5
5
8 02
97
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
5
8 03
98
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 04
99
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 05
100
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 06
101
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 07
102
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 08
103
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 09
104
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 10
105
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 11
106
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 12
107
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 13
108
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 14
109
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 15
110
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8 16
111
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
8

你可能感兴趣的:(深度学习)