学习: 人工智能实践:Tensorflow笔记(四) 4.1

  1 #coding:utf-8
  2 #此程序是4.1-损失函数单元,运用MSE作为损失函数的
  3 #题目是预测日销量y,因变量是x1,x2。
  4 #数据样本,Y_是y_=x1+x2,噪声-0.05-0.05产生
  5 #导入模块
  6 import tensorflow as tf
  7 import numpy as np
  8 BATCH_SIZE=8
  9 #SEED赋值使得产生的随机数与老师的一样,方便对比检查,实际运用时不用这样
 10 SEED=23455
 11 
 12 rdm=np.random.RandomState(SEED)
 13 X=rdm.rand(32,2)
 14 Y_=[[x1+x2+rdm.rand()/10.0-0.05] for (x1,x2) in X]
 15 
 16 #1定义神经网络的输入、参数和输出,定义前向传播过程。
 17 x=tf.placeholder(tf.float32,shape=(None,2))
 18 y_=tf.placeholder(tf.float32,shape=(None,1))
 19 w1=tf.Variable(tf.random_normal([2,1],stddev=1,seed=1))
 20 y=tf.matmul(x,w1)
 21 
 22 #2定义损失函数以及反向传播方法,注意,此处是定义,相当于把枪配置好
 23 #定义损失函数为MSE,反向传播方法为梯度下降GD
 24 loss_mse=tf.reduce_mean(tf.square(y_-y))
 25 train_step=tf.train.GradientDescentOptimizer(0.001).minimize(loss_mse)
 26 #tf把神经网络过程统一规划了,从而可以方便调用函数
 27 
 28 #3生成回话,训练STEPS轮
 29 #这句话是模板了,哈哈
 30 with tf.Session() as sess:
 31     init_op=tf.global_variables_initializer()
 32     sess.run(init_op)
 33     STEPS=20000
 34     #PYTHON的for语法也有意思,哈哈
 35     for i in range(STEPS):
 36         start=(i*BATCH_SIZE)%32
 37         end=(i*BATCH_SIZE)%32+BATCH_SIZE
 38         sess.run(train_step,feed_dict={x:X[start:end],y_:Y_[start:end]})
 39         if i%500==0:
 40             print "AFTER %d training steps,w1 is: "%(i)
 41             print sess.run(w1),"\n"
 42     print "Final w1 is :\n", sess.run(w1)
 43 

部分运行结果

Final w1 is :
[[0.98019385]
 [1.0159807 ]]

2.BATCH_SIZE的作用

本文取值是8,当然由于样本是32组,BATCH_SIZE取值是1到32之间

一般来说,在合理的范围之内,越大的 batch size 使下降方向越准确,震荡越小;batch size 如果过大,则可能会出现局部最优的情况。小的 bath size 引入的随机性更大,难以达到收敛,极少数情况下可能会效果变好。

在合理范围内,增大 Batch_Size 有何好处

  • 内存利用率提高了,大矩阵乘法的并行化效率提高。
  • 跑完一次 epoch(全数据集)所需的迭代次数减少,对于相同数据量的处理速度进一步加快。
  • 在一定范围内,一般来说 Batch_Size 越大,其确定的下降方向越准,引起训练震荡越小。

学习: 人工智能实践:Tensorflow笔记(四) 4.1_第1张图片
盲目增大 Batch_Size 有何坏处

  • 内存利用率提高了,但是内存容量可能撑不住了。
  • 跑完一次 epoch(全数据集)所需的迭代次数减少,要想达到相同的精度,其所花费的时间大大增加了,从而对参数的修正也就显得更加缓慢。
  • Batch_Size 增大到一定程度,其确定的下降方向已经基本不再变化。

参考:https://blog.csdn.net/juronghui/article/details/78612653

 

你可能感兴趣的:(tensorflow,linux)