PNAS:渐进式神经网络搜索,准确率预测,21倍加速 | ECCV2018

论文将核心放在搜索加速方面,基于NASNet,提出渐进式的PNAS搜索策略以及通过代理函数直接预测网络的准确率,极大地优化搜索逻辑,能够在搜索到相同性能的前提下,将搜索消耗降低21倍之多

来源:【晓飞的算法工程笔记】 公众号

论文: Progressive Neural Architecture Search

  • 论文地址:https://arxiv.org/abs/1712.00559

Introduction


  目前神经网络架构搜索主要有强化学习(RL)和进化算法(EA),尽管两种方法的效果都不错,但是搜索的资源耗费相当巨大,比如NASNet需要500块GPU计算4天。论文基于NASNet的搜索逻辑(具体可以看公众号之前的NASNet解读),提出启发式的搜索方法来搜索单元(cell)的结构,从简单的结构开始,在上一轮的结构上逐步生成更复杂的结构,直到搜索出符合大小的结构为止。由于训练和测试的耗时,论文将学习一个模型或代理函数来直接评价结构的好坏,每次预测个结构,然后使用代理函数选取top K个结构进行训练预测
  对比与直接搜索完整的结构,论文提出的渐进式(simple to complex)方法更效率。相对NASNet,仅搜索和训练1/5的模型,使用1/8的计算总量,此外还有以下几个特点:

  • 简单的结构训练速度快,能够快速得到一些初始结果用来训练代理函数
  • 代理函数只预测稍微比之前的训练集复杂一点的结构,完全能够胜任
  • 将搜索空间分解成更多的小搜索空间,能够允许搜索包含更多block的复杂单元

Architecture Search Space


Cell Topologies

  单元可以表示为B个block的有向图(DAG)结构,每个Block将两个tensor输入并输出一个tensor,可定义为一个5-tuple(,,,,),,为单元中之前的block的输出加前两个单元的输出。为tensor合并操作,由于之前实验发现RNN很少选择concatentaion,所以直接指定elememt-wise相加操作,不用再预测了。为算子,包含以下8种常见的操作,少于NASNet的13种:

  论文量化了搜索空间的大小,假设b’th block为,,,,假设,则。以此类推,由于单元具有对称性,可以对其进行相应的剪枝,比如有136个由单一block组成的单元,最终搜索空间为,远低于NASNet的,但仍是个庞大的搜索空间,需要高效的优化方法

From Cell to CNN

  将搜索到的单元按照预定的数量堆叠成卷积网络,如图1所示。卷积参数使用stride1或stride2,每当stride为2时,需要将卷积核数量增加两倍,最后用global average pooling接softmax进行分类。需要注意的是,NASNet特别将单元分为Normal和Reduction,但本文统一了。CIFAR-10网络的图片为,ImageNet的输入为或,在开头加了个卷积来降低计算消耗

Method


Progressive Neural Architecture Search

  论文提出渐进式的搜索顺序,先搜索简单的网络,构建(仅包含单个block)的所有单元加入到队列中进行并行的训练和验证,然后在基础上添加所有的block构建新单元,得到深度为2的候选单元结构集合。由于模型数量巨大,基于所有已经训练和验证的模型训练一个预测函数,用来直接预测所有候选单元结构的准确率,选择top K个候选单元加入到队列中,重复以上的步骤直到满足B个block,具体可以看算法1和图2

Performance Prediction with Surrogate Model

  如上所述,搜索需要预测单元的准确率的预测器,这个预测器至少需要3个特性:

  • Handle variable-sized inputs,预测器需要处理不定长的输入,例如在使用b block的单元训练的预测器需要能够预测b+1 block的单元
  • Correlated with true performance,预测器不需要预测绝对的准确率,只需要能够预测与准确率大致一样的顺序排列
  • Sample efficiency,预测器训练和验证的数据集越少越好,这意味着训练集是不足的

  为了满足以上特性,论文使用读取4b长度序列的LSTM进行预测,每个state的输入是或,最后的state通过全连接层和sigmoid来回归验证的准确率。论文也尝试使用简单的MLP来进行分类,将单元结构转化为固定长度的向量,首先将每个block的token编码为D维向量,然后将每个block concatenate成一个4D维向量,最后进行block间平均,训练使用L1 loss
  在训练时,一般直接在新数据上使用SGD进行小量迭代。但由于数据集非常小,论文将数据分成5分,每分包含4/5数据来训练5个预测器,这样能够有效降低预测的方差

Experiments and Results


Experimental Details

  对于MLP预测器,编码长度为100,使用两个全连接层,每层100个节点。对于RNN预测器,隐藏层和编码的长度都为100。每轮获取个神经网络,最大的单元为,第一个单元的初始卷积核数量为,每个单元重复,每个子网络训练20个周期,初始学习率为0.01,使用cosine decay

Performance of the Surrogate Predictors

  论文对比了不同预测器的效果,由于PNAS中预测器需要预测未训练过的大小的单元,所以实验既测试b block的准确率也测试b+1 block的准确率

  为随机选择的包含b block的单元集合,为集合总数,每个模型训练总共训练20 epoch,使用这个随机数据集来验证预测器的准确率,具体如算法2所示,返回模型集合的准确率。对于每个大小的每次实验,从选取个单元作为预测器的训练集进行训练或finetune,然后在训练集和未见过的上验证准确率

  图3展示了MLP预测器的结果,其它类型预测器的结果也大致一样,第一行和第二行分别为和的实际准确率和预测准确率的点图。从结果来看,预测器在训练集上的表现不错,但是在预测大型结构时效果不好

  论文用斯皮尔曼等级相关系数(Spearman rank correlation coefficient)来评估这些散点,从Table 1结果看来,在训练集上,RNN表现比MLP好,但在需要实际预测的大结构集合上,MLP表现稍微好点。另外,前面提到的ensembling操作也有一点帮助

Search Efficiency

  论文对比了PNAS与随机搜索和NAS的性能,对于PNAS,记录每轮选择的()和准确率,最后记录top 准确率,重复进行5次训练以保证结果的准确性。从图4结果看出,增加训练的模型数量,整体的表现都在稳步上升。PNAS的效果比随机搜索要好,与NAS相当,但NAS准确率提升速度较慢
  这里论文爆料NAS的模型选取其实是分两部分的,首先训练20 epoch选取top 250的模型,然后再训练300 epoch来选取best模型。在两种方法达到同一top-1准确率的前提下,NAS的计算总量(实际完整搜索的设定为20000)为,而PNAS则为,所以PNAS的计算总量仅为NAS的1/8

Results on CIFAR-10 Image Classification

  论文搜索到的最好结构如图1左,命名为PNASNet-5,在搜索完后,在保证保持模型总量在3M左右的前提下,对不同的N和F进行尝试,共训练300 epoch,学习率为0.025,使用cosine decay,在找到最佳的模型后对其训练600 epoch。另外,训练时,在模型2/3处添加Inception的辅助分类器,权重为0.4,也使用DropPath,随机值为0.4

  结果如Table 3所示,PNAS的准确率与NAS一致,但仅用了1/21的计算量。虽然AmoebaNet达到了最高的准确率,但是其需要花费63倍的搜索时间

Results on ImageNet Image Classification

  为了验证PNASNet-5的性能,进行了两种实验,两种实验都使用RMSProp优化器,label smoothing为0.1,在2/3位置使用权重为0.4的辅助分类器,softmax dropout为0.4,droppath为0.4:

  • Mobile,输入图片为,模型计算量限制在600M内
  • Large,输入图片为,与SOTA模型对比

  Mobile的结果如Table 4,PNASNet-5稍微比NASNet-A的表现稍好,比之前人工设定的网络要好,其中性能最佳的是AmoebaNet-C的性能最佳,但这个并非在CIFAR-10上的表现好的结构

  Large的结果如Table 5,PNASNet-5模型超越了之前的同体积的SOTA模型

CONCLUSION


  目前越来越多的目光聚集在神经网络架构搜索这一领域,但搜索花费一直是个巨大的问题,目前只有一些计算资源丰富的巨头公司能够玩转。论文将核心放在搜索加速方面,基于NASNet,提出渐进式的PNAS搜索策略以及通过代理函数直接预测网络的准确率,极大地优化搜索逻辑,能够在搜索到相同性能的前提下,将耗时提升21倍,十分值得参考



写作不易,未经允许不得转载~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

你可能感兴趣的:(PNAS:渐进式神经网络搜索,准确率预测,21倍加速 | ECCV2018)