NAS是自动设计网络结构的重要方法,但需要耗费巨大的资源,导致不能广泛地应用,而论文提出的Efficient Neural Architecture Search(ENAS),在搜索时对子网的参数进行共享,相对于NAS有超过1000x倍加速,单卡搜索不到半天,而且性能并没有降低,十分值得参考
来源:【晓飞的算法工程笔记】 公众号
论文: Efficient Neural Architecture Search via Parameter Sharing
- 论文地址:https://arxiv.org/abs/1802.03268
Introduction
神经网络结构搜索(NAS)目前在图像分类的模型结构设计上有很大的成果,但十分耗时,主要花在搜索到的网络(child model)的训练。论文的主要工作是提出Efficient Neural Architecture Search(ENAS),强制所有的child model进行权重共享,避免从零开始训练,从而达到提高效率的目的。虽然不同的模型使用不同的权重,但从迁移学习和多任务学习的研究结果来看,将当前任务的模型A学习到的参数应用于别的任务的模型B是可行的。从实验看来,不仅共享参数是可行的,而且能带来很强的表现,实验仅用单张1080Ti,相对与NAS有1000x倍加速
Methods
NAS的搜索结果可以看作是大图中的子图,可以用单向无环图(DAG)来表示搜索空间,每个搜索的结构可以认为是图2的DAG一个子网。ENAS定义的DAG为所有子网的叠加,其中每个节点的每种计算类型都有自己的参数,当特定的计算方法激活时,参数才使用。因此,ENAS的设计允许子网进行参数共享,下面会介绍具体细节
Designing Recurrent Cells
为了设计循环单元(recurrent cell),采用节点的DAG,节点代表计算类型,边代表信息流向,ENAS的controller也是RNN,主要定义:1) 激活的边 2) 每个节点的计算类型。在NAS(Zoph 2017),循环单元的搜索空间在预先定义结构的拓扑结构(二叉树)上,仅学习每个节点的计算类型,而NAS则同时学习拓扑结构和计算类型,更灵活
为了创建循环单元,the controller RNN首先采样个block的结果,取,为当前单元输入信息(例如word embedding),为前一个time step的隐藏层输出,具体步骤如下:
- 对于node 1:controller选择一种激活方法,如图1选择了tanh,即node 1计算
- 对于node 2:controller预测一个之前的索引和激活方法,如图1选了1和ReLU,即计算
- 对于node 3:controller也预测一个之前的索引和激活方法,如图1选了2和ReLU,即计算
- 对于node 4:controller同样预测一个之前的索引和激活方法,如图1选了1和tanh,即计算
- 对于输出:将所有没被选择的节点进行平均,如图1,则将节点3和4作为输出,即输出
注意到每对节点()都有独立的参数,根据选择的索引决定使用哪个参数,因此,ENAS的所有循环单元能同一个共享参数集合。论文的搜索空间包含指数数量的配置,假设有N个节点和4种激活函数,则共有种配置
Training ENAS and Deriving Architectures
ENAS的controller为100个隐藏单元的LSTM,通过softmax分类器以自回归(autoregressive fashion)的方式进行选择的决定,上一个step的输出作为下一个step的输入embedding,controller的第一个step则接受空embedding输入。学习的参数主要有controller LSTM的参数和子网的共享权重,ENAS的训练分两个交叉的阶段,第一阶段在完整的训练集上进行共享权重学习,第二阶段训练controller LSTM的参数
-
Training the shared parameters of the child models
固定controller的策略,然后进行进行随机梯度下降(SGD)来最小化交叉熵损失函数的期望,为模型在mini-batch上的交叉熵损失,模型从采样而来
梯度的计算如公式1,上从采样来的,集合所有模型的梯度进行更新。公式1是梯度的无偏估计,但有一个很高的方差(跟NAS一样,采样的模型性能差异),而论文发现,当时,训练的效果还行
-
Training the controller parameters
固定然后更新策略参数,目标是最大化期望奖励,使用Adam优化器,梯度计算使用Williams的REINFORCE方法,加上指数滑动平均来降低方差,的计算在独立的验证集上进行,整体基本跟Zoph的NAS一样
-
Deriving Architectures
训练好的ENAS进行新模型构造,首先从训练的策略采样几个新的结构,对于每个采样的模型,计算其在验证集的minibatch上的准确率,取准确率最高的模型进行从零开始的重新训练,可以对所有采样的网络进行从零训练,但是论文的方法准确率差不多,经济效益更大
Designing Convolutional Networks
对于创建卷积网络,the controller每个decision block进行两个决定,这些决定构成卷积网络的一层:
- 选择连接的节点(previous nodes) ,可选多个,允许产生含skip connection的网络。在层k,有k-1个不同节点可以选择,共种决定
- 计算类型,共6种算子: convolutions with filter sizes 3 × 3 and 5 × 5, depthwise-separable convolutions with filter sizes 3×3 and 5×5, and max pooling and average pooling of kernel size 3 × 3,每层的每种算子都有独立的参数
做次选择产生层的网络,共种网络,在实验中,L取12
Designing Convolutional Cells
NASNet提出设计小的模块,然后堆叠成完整的网络,主要设计convolutional cell和reduction cell
使用ENAS生成convolutional cell,构建B节点的DAG来代表单元内的计算,其中node 1和node 2代表单元输入,为完整网络中前两个单元的输出,剩余的个节点,预测两个选择:1) 选择两个之前的节点作为当前节点输入 2) 选择用于两个输入的计算类型,共5种算子:identity, separable convolution with kernel size 3 × 3 and 5 × 5, and average pooling and max pooling with kernel size 3×3,然后将算子结果相加。对于,搜索过程如下:
- Node 1和Node 2为输入节点,不需要选择,定义和为节点输出
- 对于node 3:the controller选择之前的节点中的两个作为输入,为每个输入选择一种算子,在图5上左的结果,进行
- 对于node 4:与node 3一样,图5上右的结果,进行
- 最后由于只有node 4的输出还没被使用,直接将作为单元输出
对于reduction cell,可以同样地使用上面的搜索空间生成: 1) 如图5采样一个计算图 2) 将所有计算的stride改为2。这样reduction cell就能将输入缩小为1/2,controller共预测blocks
最后计算下搜索空间的复杂度,对于node i,troller选择前个节点中的两个,然后选择五种算子的两种,共种坑的单元。因为两种单元是独立的,所以搜索空间的大小最终为,对于,大约种网络
Experiments
Language Model with Penn Treebank
节点的计算做了一点修改,增加highway connections,例如修改为,其中,为elementwise乘法。搜索到的结果如图6所示,有意思的是:1) 激活方法全部为tanh或ReLU 2) 结构可能为局部最优,随机替换节点的激活函数都会造成大幅的性能下降 3) 搜索的输出是6个node的平均,与mixture of contexts(MoC)类似
单1080Ti训练了10小时,Penn Treebank上的结果如表1所示,PPL越低则性能越好,可以看到ENAS不准复杂度低,参数量也很少
Image Classification on CIFAR-10
表2的第一块为最好的分类网络DenseNet的结构,第二块为ENAS设计整个卷积网络的结果(感觉这里不应有micro search space),第三块为设计单元的结果
全网络搜索的最优结构如图7所示,达到4.23%错误率,比NAS的效果要好,大概单卡搜索7小时,相对NAS有50000x倍加速
单元搜索的结构如图8所示,单卡搜索11.5小时,,错误率为3.54%,加上CutOut增强后比NASNet要好。论文发现ENAS搜索的结构都是局部最优的,修改都会带来性能的降低,而ENAS不采样多个网络进行训练,这个给NAS带来很大性能的提升
CONCLUSION
NAS是自动设计网络结构的重要方法,但需要耗费巨大的资源,导致不能广泛地应用,而论文提出的Efficient Neural Architecture Search(ENAS),在搜索时对子网的参数进行共享,相对于NAS有超过1000x倍加速,单卡搜索不到半天,而且性能并没有降低,十分值得参考
写作不易,未经允许不得转载~
更多内容请关注 微信公众号【晓飞的算法工程笔记】