Linux Pytorch ResNet-18 cifar10 实践报告

Linux Pytorch ResNet-18 cifar10 实践报告

  • 硬件资源
  • 环境版本
  • 实验方法
    • 基本参数设置
  • 实验结果
  • 结果分析
    • 1. ResNet-v1 VS ResNet-v2
    • 2. ResNet-v2 VS ResNet-v2+TrivialAugment
    • 3. MixUp vs CutMix vs TrivialAugment

硬件资源

cpu: Intel(R) Core(TM) i5-7500 CPU @ 3.40GHz
显卡: 1080Ti
内存: 16G

环境版本

#系统信息
Distributor ID:	Ubuntu
Description:	Ubuntu 16.04.5 LTS
Release:	16.04
Codename:	xenial
#主要依赖		
torch              1.10.0
torchvision        0.11.1
#CUDA信息
~$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
~$ cat /usr/local/cuda/version.txt
CUDA Version 10.2.89

实验方法

基于PyTorch使用ResNet-18模型训练cifar10数据集

  1. 对比 ResNet-v1 和 ResNet-v2 的测试集准确率
  2. 对比三种当前比较先进的数据增强方法(MixUp、CutMix、TrivialAugment)的测试集准确率

基本参数设置

#训练集 测试集比例
5:1 即训练集50000张,测试集10000张
# 超参数设置
EPOCH = 100  # 遍历数据集次数
BATCH_SIZE = 512  # 批处理尺寸(batch_size)
LR = 0.1  # 学习率

# 基本的随机增强
RandomCrop
RandomHorizontalFlip

# 损失函数
CrossEntropyLoss #交叉熵
# 优化方式
optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
# 学习率迭代策略
#在指定的epoch值,[60, 90]处对学习率进行衰减,lr = lr * gamma
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[60,90], gamma=0.1)

实验结果

实验 测试集准确率
ResNet-v1 94.13%
ResNet-v2 94.23%
ResNet-v2 + RandomMixup 95.10%
ResNet-v2 + RandomCutmix 95.46%
ResNet-v2 + TrivialAugment 95.27%

Linux Pytorch ResNet-18 cifar10 实践报告_第1张图片

最终实验结果表明:

  • ResNet-v2的网络模型优于ResNet-v1。
  • 三种当前比较先进的数据增强方法(MixUp、CutMix、TrivialAugment)都有不俗的作用,一定程度上提升了准确率。
  • 针对cifar10数据集,用ResNet-v2的ResNet-18训练时,CutMix的数据增强手段最优。

结果分析

1. ResNet-v1 VS ResNet-v2

Linux Pytorch ResNet-18 cifar10 实践报告_第2张图片
由上图可以看出,ResNet-v2相比于ResNet-v1,训练集(蓝线)和测试集准确率(粉红线)都更早更快得达到一个比较好的效果,也就是训练更容易,最终的测试集准确率也超出一点点(94.23% vs 94.13%)。
Linux Pytorch ResNet-18 cifar10 实践报告_第3张图片

ResNet-v2重新设计了一种残差网络基本单元(unit)就是将激活函数(先BN再ReLU)移到权值层之前,形成一种“预激活(pre-activation)”的方式,如上图(b),而不是ResNet-v1中常规的“后激活(post-activation)”方式,如上图(a),并且预激活的单元中的所有权值层的输入都是归一化的信号。这使得网络更易于训练并且泛化性能也得到提升。

2. ResNet-v2 VS ResNet-v2+TrivialAugment

Linux Pytorch ResNet-18 cifar10 实践报告_第4张图片
从上图可以看到,

  • ResNet-v2训练过程中,训练集准确率(绿线)一直高于测试集准确率(红线),并且最后训练集准确率接近100%,而测试集准确率仅有94.23%。
  • 加上TrivialAugment的数据增强方法后,训练全程测试集准确率(粉红线)一直高于训练集准确率(蓝线),最后训练集准确率达到92%左右,而测试集准确率达到95.27%。
  • 以上结果说明加上TrivialAugment的数据增强方法后,网络过拟合程度大大减小。其原因不难看出是因为强大的数据增强扩充了数据集,增强了网络的泛化能力。

3. MixUp vs CutMix vs TrivialAugment

Linux Pytorch ResNet-18 cifar10 实践报告_第5张图片
几种数据增强的区别:MixUp vs CutMix vs TrivialAugment

  • MixUp:将随机的两张样本按比例混合,分类的结果按比例分配
    Linux Pytorch ResNet-18 cifar10 实践报告_第6张图片

  • CutMix:将一部分区域cut掉但不填充0像素而是随机填充训练集中的其他数据的区域像素值,分类结果按一定的比例分配
    Linux Pytorch ResNet-18 cifar10 实践报告_第7张图片

  • TrivialAugment:每次随机选择一个图像增强操作,然后随机确定它的增强幅度,并对图像进行增强。由于没有任何超参数,所以不需要任何搜索
    Linux Pytorch ResNet-18 cifar10 实践报告_第8张图片

  • TrivialAugment通过组合多种数据增强手段,随即增强强度,可能会有不稳定因素

  • mixup是将两张图按比例进行插值来混合样本,cutmix是采用cut部分区域再补丁的形式去混合图像,不会有图像混合后不自然的情形

  • cutmix通过要求模型从局部视图识别对象,对cut区域中添加其他样本的信息,能够进一步增强模型的定位能力

  • cutmix不会有图像混合后不自然的情形,能够提升模型分类的表现,最终测试准确率也更高一些

你可能感兴趣的:(pytorch,linux,cifar10,图像分类,数据增强)