神经网络训练相关参数设置

batch_size:

机器学习使用训练数据进行学习,针对训练数据计算损失函数的值,找出使该值尽可能小的参数。但当训练数据量非常大,这种情况下以全部数据为对象计算损失函数是不现实的。因此,我们从全部数据中选出一部分,作为全部数据的“近似”。神经网络的学习也是从训练数据中选出一批数据(称为 mini-batch,小批量),然后对每个mini-batch进行学习,值的大小与梯度下降的效率和结果直接相关。
比如,从60000个训练数据中随机选取100个数据,用这100个数据进行学习,这种学习方式成为 mini-batch 学习。用mini-batch的方法定义batch_size,把数据分成小批量,每小批的大小batch_size=100,共有600个这样的批次,即一个epoch梯度下降600次。

批量梯度下降(BGD):一个epoch训练所有的样本后更新一遍梯度。
假设训练样本共m个,你设置batch_size为1,则每个样本都是一个batch_size。
你设置batch_size为m,则所有样本组成这一个batch_size。
1与m也是两个极端。 当设置为m时,此时梯度下降称为批量梯度下降法。你可以理解为神经网络进行梯度下降时从最远的点,每次迭代需要遍历整个训练集。(所以需要很大的显存空间,如果你的样本数据不大,可以选择将batch_size设置为m)

随机梯度下降(SGD):每训练一个样本,更新一遍梯度。
当设置为1时,此时梯度下降称为随机梯度下降法。你可以理解为神经网络进行梯度下降时随机找一个点,每次迭代只处理一个训练数据(所以需要很长的时间来完成训练)

综合上述,选择一个合适大小的 batch_size是很重要的,因为计算机字符都是以2的指数次幂进行存储的,所以设置 batch_size时尽量选择例如 4, 8, 16, 32, 64, 128, 256 等。
batch_size越大,速度越快,精度越低(相同训练轮数)。

Batch_size的调参:
1.当有足够算力时,选取batch size为32或更小一些。
2.算力不够时,在效率和泛化性之间做trade-off,尽量选择更小的batch size。
3.当模型训练到尾声,想更精细化地提高成绩(比如论文实验/比赛到最后),有一个有用的trick,就是设置batch size为1,即做纯SGD,慢慢把error磨低。

当然增大Batch_size会加快速度,但是变相地需要更多的Epoch(轮数)去达到需要的精度。

yolov5的作者建议在显存足够的情况下Batch_size设置越大越好,可以使用YOLOv5 AutoBatch (NEW) 通过设置**–batch-size - 1**进行训练。从而找到显存最大利用率下的batch-size值. 根据您的训练设置,AutoBatch 将解决 90% 的 CUDA 内存利用率批处理大小。AutoBatch 是实验性的,仅适用于单 GPU 训练。如下图,可得出最大batch-size为179。来自该issue
神经网络训练相关参数设置_第1张图片


iteration:

迭代,即训练学习循环一遍(寻找最优参数(权重和偏置))。比如 iteration=30000,循环一遍即执行了30000次迭代。当 batch_size=100,可以说执行完一遍 iteration,即执行了30000次 batch_size。


epoch:

epoch 是一个单位。一个 epoch表示学习中所有训练数据均被使用过一次时的更新次数。比如,对于1000个训练数据,用大小为100个数据的mini-batch(batch_size=100)进行学习时,重复随机梯度下降法100次,所有的训练数据就都被“看过”了。此时,10次就是一个 epoch。(即:遍历一次所有数据,就称为一个 epoch)。
实例:
下面展示一些 内联代码片

训练数据量:60000
mini-batch方法:batch_size = 100
迭代次数:iteration = 30000
平均每个epoch的重复次数:60000 / 100 = 600
当迭代进行到600次时,即视为完成了一个epoch
30000 / 600 = 50

从这个实例可以看出,执行完一遍 iteration,完成了50个 epoch。

对于Epoch大小的确定,牵扯到了防止过拟合的一个方法:提前停止训练(early stopping)。
随着epoch次数增加,神经网络中的权重的更新次数也增加,模型从欠拟合变得过拟合。

trick:
可以先设定一个固定的Epoch大小(100轮)
一般当模型的loss不再持续减小,且精度不在10轮内提升,就可以提前停止训练了。(设置条件来停止epoch)


转自:
https://blog.csdn.net/qq_38358305/article/details/88643163
https://blog.csdn.net/stay_foolish12/article/details/107386434

你可能感兴趣的:(神经网络,深度学习,python)