NASNet论文详解

NASNet,论文的全名叫做Learning Transferable Architectures for Scalable Image Recognition.

这一篇论文是对神经网络架构搜索开篇之作NAS的集成和发展,也是由谷歌的Zoph等人提出来的,针对NAS论文中的缺点进行改进,在分类精度和训练资源、时间上,都优于前者。

NASNet论文的基本设计思想是:

  1. 和NAS论文一样,采用controller RNN来预测子网络参数
  2. 第一次提出了Cell和Block的概念
  3. controller RNN不再用来预测每一层的网络参数,而是用来预测Cell里面的Block参数

首先介绍一下什么是Cell和Block。Cell可以看做是整体网络架构里面的一个单元块,类似ResNet架构的残差块或者MobileNet V2的bottleneck,整个网络就是由这些单元块堆叠连接而成。

Cell分两种:Normal和Reduction。当输入特征和输出特征的分辨率是一致时,采用Normal Cell,当输入特征的分辨率是输入特征的一半时,采用Reduction Cell。Reduction Cell的设计方法Normal Cell基本一样,只是在输入特征上添加了一个stride=2的卷积操作,降低分辨率。在整体网络架构中,Normal Cell和Reduction Cell的设计原则是每N个Normal Cell中插入一个Reduction Cell,如下图所示。
NASNet论文详解_第1张图片
图1. Cifar-10和ImageNet上的NASNet网络架构

Block是Cell里面的基本单元,共有B个(论文取5)。每个Block有两个输入,分别经过各自的operation之后再结合(相加或者衔接)作为输出,Block的输出称为隐状态。对于第 i i i个Block,输入的候选范围包括前面 i − 1 i-1 i1个Block的隐状态以及前两个Cell的输出,Block的操作的候选空间如下图所示。
NASNet论文详解_第2张图片
图2. Block操作的候选空间

与NAS论文里controller RNN预测每一个layer的操作参数不同,NASNet的controller RNN是用来预测Cell里面每一个Block的参数。具体如下图所示。
NASNet论文详解_第3张图片
图3. NASNet的controller RNN

Block的参数预测步骤有:

  1. 从输入候选范围内选择两个隐状态作为Block的两个输入
  2. 从操作候选空间选择operation作为步骤1中两个输入的操作
  3. 选择一个操作用来结合步骤2中的两个输出

预测步骤总共会循环B次,直至预测出Cell所有Block结构为止。

Controller RNN的训练方法和NAS论文中一样,也是通过验证集的精度作为reward来优化controller的参数,采用的强化学习中的PPO(Proximal Policy Optimization)算法。

在训练的时候,只选择一种Normal和Reduction Cell,同一个网络中相同类型的Cell结构是共享的,所以controller RNN只需要预测一个Cell的结构即可。从搜索空间的复杂度来看,这种方法设计极大地减小了搜索的次数和范围,这种思想被后来的其他NAS论文广泛引用,后面的博客介绍的其他方法会持续提到。

作者在训练的过程还加了一种额外的技巧,即先在小的数据集上(如Cifar-10)搜索Cell结构,等搜索结果出来后,再堆叠更多的Cell,应用在大数据集上(如ImageNet)。这样在搜索的过程中,子网络模型训练的时间便大幅减小,提高搜索的效率。

在Cifar-10数据上,论文使用了500个GPU,搜索了4天的时间。相比NAS论文的实验,搜索效率提升了7倍。在训练子网络时,采用Scheduled DropPath的方法,以一定的概率(随着迭代的次数线性增加)随机扔掉Cell里的某些路径。下图是NASNet搜索出来的Normal和Reduction Cell的结构。
NASNet论文详解_第4张图片
图5. 搜索出来的Normal和Reduction Cell结构图

论文把cifar-10上搜索出来的Cell结构迁移到ImageNet数据集上,表现出了很好的泛化能力。

以下两张图是NASNet搜索出来的Cell按照图1里的方式叠加成网络后训练出来的结果。可以看出,在同一参数量等级的模型上,NASNet比手工设计的网络模型精度更好,也比NAS论文的实验结果更优。
NASNet论文详解_第5张图片
图6. Cifar-10实验结果和对比

在这里插入图片描述
图7. ImageNet实验结果和对比

你可能感兴趣的:(AutoML,Deep,Learning)