我们提出了基于多层感知器的图像分类体系结构ResMLP。它是一种简单的残差网络,它可以(i)替代一个线性层,其中图像patch在各个通道之间独立而相同地相互作用,以及(ii)替代一个两层前馈网络,其中每个通道在每个patch之间独立地相互作用。当使用大量数据增强和选择性蒸馏的现代训练策略进行训练时,它在ImageNet上获得了惊人的准确性/复杂性。我们将基于Timm库和预先训练过的模型共享我们的代码。
最近,transformer架构[52](从其最初在自然语言处理中的使用进行了调整,只做了很小的更改,在使用足够大的数据[13]进行预训练时,在ImageNet-1k[43]上实现了与当前水平相当的性能。回顾一下,这一成就是发展的又一步:与人工设计的前CNN方法相比,卷积神经网络删除了许多手工选择,将硬连接功能的范式转移到手工设计的架构选择。vision transformer避免了对卷积架构和平移不变性的固有假设。
最近这些基于transformer的工作表明,更长的训练计划、更多的参数、更多的数据[13]和/或更多的正则化[49],足以恢复像ImageNet分类这样复杂任务的重要先验。请参见第4节中我们对相关工作的讨论。这与最近的研究一致[2,12],即更好地将体系结构的好处与训练方法的好处分开。
在本文中,我们进一步推动这一趋势,并提出残差多层感知器(ResMLP):一种纯粹基于多层感知器(MLP)的图像分类体系结构。我们在图1中概述了我们的体系结构,并在第2节中进一步详细说明。它的目标是简单的:以拉平的patch作为输入,用线性层投影它们,然后用两个residual操作依次对它们进行更新:(i)一个简单的线性层,提供patch之间的交互作用,独立应用于所有通道;(ii)具有单一隐藏层的MLP,该MLP独立应用于所有patch。在网络的最后,将patch进行平均池处理,并将其送入线性分类器。
这个架构受到vision transformer(ViT)[13]的强烈启发,但它在几个方面更简单:我们不使用任何形式的注意,只有线性层和GELU非线性。由于我们的体系结构训练起来要比transformer稳定得多,所以我们不需要批特定的或跨通道的标准化,如Batch-Norm,GroupNorm 和 LayerNorm。我们的训练程序主要遵循最初为DeiT[49]和CaiT[50]引入的程序。
由于其线性特性,我们的模型中的patch相互作用可以很容易地可视化和解释。虽然在第一层学习到的交互模式非常类似于一个小型卷积过滤器,但我们观察到在更深层次的补丁之间更微妙的交互。这包括某种形式的轴向过滤,以及网络早期的长期交互作用。总之,在本文中,我们证明了
尽管简单,残差多层感知器可以达到惊人的准确性/复杂性,与ImageNet-1k训练,而不需要基于批处理或通道统计的标准化;
这些模型显著受益于蒸馏方法[49];
•由于它的设计,patch embeddings简单地通过一个线性层“沟通”,我们可以观察到网络跨层学习的空间交互类型。
我们的模型,如图1所示,是受ViT模型的启发,采用了path flattening结构。我们进行了极端的简化。关于ViT架构的更多细节,请参阅Dosovitskiyet al.[13]。
整个ResMLP体系结构。我们的模型,用ResMLP表示,以网格N×N、非重叠patch作为输入,其中N通常等于16。然后,这些小块被独立地穿过一层线性层,形成一组嵌入。
连接(embedding)的集合被输入到一个残差多层感知器层序列,以产生一组
维输出embedding。然后将这些输出embedding值平均为d维向量来表示图像,将其输入线性分类器以预测与图像相关的标签。训练使用交叉熵损失。
残差多感知器层。我们的网络是一系列具有相同结构的层:一个线性子层后面跟着一个前馈子层。与Transformer层类似,每个子层都与跳过连接[19]并行。我们没有应用层归一化[1],因为在使用下面的仿射变换时,没有它的训练是稳定的:
……(1)
其中α和β是可学习的向量。这个操作只是按输入组件进行缩放和移动。此外,它在推理时没有代价,因为它可以在相邻区域融合线性层。注意,当写入aff (X)时,该操作独立应用于X的每一列。虽然类似于BatchNorm[24]和Layer Normalization[1],但aff操作符不依赖于任何批统计。因此,它更接近于最近的LayerScale方法[50],这提高了在初始化α到一个小值时深度变压器的优化。请注意,LayerScale没有偏倚项。
我们对每个残块应用这个变换两次。作为一个预规范化,Aff代替了normalization,并避免使用channel-wise统计。这里,我们初始化α=1, β=0。作为残差块的后处理,Aff实现了LayerScale,因此我们遵循与[50]相同的小值初始化α进行后归一化。这两种转换在推理时被集成到线性层中。
最后,我们遵循与Transformer中相同的前馈子层结构,仅用GELU[21]替换ReLUnon-linearity。
总的来说,我们的多感知器层获取一组叠在大小、d维的输入特征,叠在一个矩阵X中;并输出一组
、d维的输出特征,叠在一个矩阵Y中,并进行以下转换:公式(2)(3)
其中,A、B和C都是层主要的可学习参数。矩阵A的维度是,即这个子层混合了来自所有位置的信息,而前馈子层在每个位置工作。因此,Z和X、Y的尺寸相同。最后,B、C矩阵与Transformer层中相同的尺寸,即
和
。
与Transformer层的主要区别是,我们用式(2)中定义的线性交互来代替自注意。而自注意计算其他特征的凸组合,这些特征与数据相关,式(2)中的线性交互层使用不依赖数据的学习系数计算一般的线性组合。与具有局部支持和跨空间共享权的卷积层相比,我们的线性patch交互层提供了全局支持而不共享权重,而且它跨通道独立应用。
与Vision Transformer的关系。我们的模型可以看作是Dosovitskiyet al.[13]的ViT模型的急剧简化。我们与这个模型的不同之处如下:
class -MLP:带有类embedding的MLP。作为平均池的替代方案,我们还试验了CaiT[50]中引入的类注意的适应性。它由两个具有与Transformer相同结构的层组成,但其中仅基于冻结patch嵌入更新类token。我们将这种方法转化为我们的网络,用简单的线性层代替类和petch嵌入之间基于注意力的交互。这提高了性能,但增加了一些参数和计算成本。我们将这种池变体称为“class-MLP”。
在本节中,我们将展示用于图像分类的ResMLP体系结构的实验结果。我们还研究了ResMLP体系结构中不同组件的影响。
数据集。我们在ImageNet-1k数据集[43]上训练模型,该数据集包含120万张图像,平均分布在1000个对象类别上。在这个基准没有可用的测试集的情况下,我们遵循社区中的标准实践,报告验证集的性能。这并不理想,因为验证集最初设计用于选择超参数。由于性能的提高可能不是由于更好的建模,而是由于更好地选择超参数,因此对这一集的方法进行比较可能还不够有说服力。为了降低这种风险,我们报告了两个不同版本的ImageNet的额外结果,这两个版本构建了不同的验证和测试集,即ImageNet-real[3]和ImageNet-v2[42]数据集。我们的超参数主要采用Touvron等[49,50]。
训练范例。在我们的实验中,我们考虑了两种训练范式:
Hyper-parameter设置。在有监督学习的情况下,我们用Lamb优化器[55]训练我们的网络,学习率,权值衰减0.2。我们按照Touvronet al.[50]为CaiT提出的现成参数,将LayerScale参数初始化为深度函数。其余超参数遵循DeiT[49]中使用的默认设置。对于知识蒸馏范式,我们使用与DeiT相同的RegNety16GF[41],并使用相同的训练计划。
在本节中,我们将我们的体系结构与ImageNet上具有同等规模和吞吐量的标准神经网络进行比较。
与Transformer和卷积神经网络在监督下的比较。在表1中,我们比较了不同卷积和Transformer架构的ResMLP。为了完整起见,我们还报告了仅在ImageNet上训练的模型获得的最佳发布数字。正如预期的那样,在准确性、FLOPs和吞吐量之间的权衡方面,ResMLP不如卷积网络或transformer好。然而,他们的准确性非常令人鼓舞。实际上,我们将它们与经过多年研究和仔细优化的体系结构进行了比较。总的来说,我们的结果表明,由层设计施加的结构约束不会对性能产生重大影响,特别是当训练模型具有足够的数据和训练和正规化方面的现代进步时。
通过知识精馏改进模型收敛性。我们还研究了我们的模型,当训练遵循Touvronet al.[49]的知识蒸馏范式。在他们的工作中,作者展示了通过从高效网络中提取ViT模型来训练它的影响。在这个实验中,我们探讨了ResMLP是否也从这个过程中受益,并在表2中总结了我们的结果。我们观察到类似于DeiT模型,ResMLP从蒸馏中获益良多。这一结果与d 'Ascoliet al.[11]的观测结果一致,后者使用卷积网络初始化前馈网络。尽管我们的设置在规模上与他们不同,但在ImageNet上,前馈网络的过拟合问题仍然存在。从蒸馏中获得的额外正则化可能是这种改进的一种解释。
在本文中,我们展示了一种简单的残差结构,其残差块由一个隐藏层前馈网络和一个线性patch交互层组成,在ImageNet分类基准上取得了出乎意料的高性能,假设我们采用了现代的培训策略,比如最近为基于Transformer的体系结构引入的培训策略。由于它们的结构简单,线性层作为patch之间的主要通信手段,我们可以直观地看到这个简单的MLP学习到的过滤器。虽然有些层类似于卷积过滤器,但我们也可以早在网络的第二层就观察到稀疏的长距离交互。我们希望我们的空间无先验模型将有助于进一步了解具有较少先验的网络学习,并潜在地指导未来网络的设计选择,而不是大多数卷积神经网络所采用的金字塔设计。