Bootstrapped MAE

摘要

基于 facebookresearch/mae: PyTorch implementation of MAE 实现了 EMA 版本与非 EMA 版本的 Bootstrapped MAE,尝试了同时预测原始像素与 bootstrapped 特征、只使用 bootstrapped 方法,并探究了不同的标准化方法、不同的学习率等的影响,在 linear evaluation 和 finetuning 中均得到了超过 baseline 的效果。

最佳结果

linear(base) finetuning(base) linear(my) finetuning(my)
64.67 81.93 68.24 85.63

encoder 架构同 DeiT-tiny,decoder 采用轻量级架构(embed_dim=192, heads=8,depth=2),调节学习率 blr=7.5e-4 得到 baseline。

实现细节与思考

linear evaluation 和 finetuning 的结果很可能不一致1。finetuning 更关注 encoder 前面几层的训练效果而不是最终输出,而且 training from scratch 就能达到 72% 的 top-1 准确率。而 linear evaluation 的结果完全依赖于 encoder 的输出。由于我们更关注模型提取高阶语义信息的能力,所以下面评价结果主要以 linear evaluation 为依据,虽然 MAE 的特征是高度线性不可分1

两个问题

对于 Bootstrapped MAE 的描述心里有两个问题:

  1. 随着训练的进行会不会丢失很多原始图像的信息?只有 MAE-1 能够得到“原始像素的反馈”,而后面的模型都是去拟合之前模型的输出,这好像是一个信息不断丢失的过程。
  2. Bootstrapped MAE 为什么会 有效?用后一个模型拟合前一个模型 encoder 的输出,不严谨的讲就像是一种“插值”,直觉上这样做和增加 decoder 层数,或者训一个比较深的网络然后取某一层的输出差不多。而且,似乎有一种矛盾:由于 MAE-1 的 encoder 的输出是包含像素级别信息的(因为 decoder 能够根据这些信息还原像素),如果 bootstrap 的过程中如果每一步都拟合的很好,训练出来的 MAE-k 始终都包含着像素的信息,占用了模型的表达能力,这是我们不希望看到的;如果拟合的不好,语义信息也会随着像素信息一起丢失。

当然后面随着对问题理解的深入以及实验数据的支撑,对这两个问题也有了一定认识。

bootstrap1

首先实现了朴素的 bootstrapped MAE,称其为 bootstrap1。借鉴了 https://github.com/rwightman/pytorch-image-models 中的 ModelEma 类,EMA 版本使用 ModelEma.update() 更新参数,非 EMA 版本的算法调用 ModelEma.set() 重置模型。loss 计算使用 MSE,只关于 unmasked patches1,使用 layer normalization。注意非 EMA 版本每次须将 optimizer 重置,否则结果很劣。

考虑到可能的信息丢失,于是让 MAE-1 训练更多的 epochs。取不同的 bootstrap 次数 k,结果如下:

k(bootstrap次数) MAE-1训练epoch数 linear evaluation
4 120 64.76
5 40 64.07
5 60 64.12
5 100 64.73
5 120 64.83
5 140 64.57
6 100 64.67

迭代 5 次, MAE-1 训练 120 epochs 时性能较好,也超过了 baseline,说明 bootstrap 的方法确实有效,但是差距不大。又尝试 EMA 版本,调节不同的衰减系数:

EMA decay linear evaluation
0.992 64.61
0.999 64.80
0.9995 64.83
0.9999 64.63

同样差距不大,与 baseline 也没有显著区别。

双 decoder

担心性能不佳是由于信息丢失导致的,我想如何能够在 MAE-k (k>1) 的训练中也加入原始的图像信息?于是我使用两个 decoder,一个decoder用于重构像素,另一个decoder用于重构前一个模型输出的特征,分别计算 loss 并按照一个权重相加2。结果如下:

loss function linear evaluation
pixel_loss 63.63
feature_loss 64.36
feature_loss+pixel_loss 65.23

提升也并不十分显著。

从随机权值开始 bootstrap

随着进一步的思考,我发现在疑问1、2中的认识是片面的。首先,不能简单地将 bootstrapped MAE 理解成“插值”或者“更深的网络”。因为 MAE-1 的输入是完整的图像,它的输出 Z 就包含着整张图的信息。如果MAE-2 能够用通过部分图像推理得到 Z,那么就可以认为它理解了图片的结构,包括高阶的语义信息。注意到,这里的 Z 其实是不需要和具体的像素信息有关的,这样迭代的模型也就不含原图像素级别的低层次信息,可以更专注于表达高阶的语义信息。也就是说如果从一个随机初始化的网络开始 bootstrap 而不是使用 vanilla MAE 作为 MAE-1,可能也会有不错的效果。

事实上,随机初始化的效果出乎意料地好,甚至比使用像素信息训练一段时间的 MAE-1 表现还要好。原因可能是随机初始化会使特征包含全局信息,而使用像素重构训练的 MAE-1 会更局限于某个patch3,因此随机初始化的表现反而更好。同时,我发现不同的初始化对应的最终模型对于同一张图片的特征表示差距也较大,这也提示着对 bootstrap 方法的初始化至关重要,如何找到比随机更好的初始化方法也许是一个值得继续研究的点。称这种方法为 bootstrap2

k(bootstrap次数) linear evaluation
4 67.00
5 67.97
6 67.85
7 68.02
8 66.53

可以发现 bootstrap2 取得了很好的结果。

尝试编写 bootstrap2 的 EMA 版本,发现 EMA 的超参数比较难调,调整了学习率以及 EMA 衰减系数,始终超不过非 EMA 版本的结果。下面是调整衰减系数的实验(blr=7.5e-4):

EMA decay linear evaluation
0.95 64.80
0.98 64.90
0.98-0.994 64.73
0.99 65.03
0.999 64.16

正则化方法

注意到 bootstrap2 最佳的 decay 比 bootstrap1 要小很多,是因为使用了标准化来缓解模型崩溃3。在 EMA 训练中 loss 绝对值非常小,测试的结果也比较差,推测是模型崩溃成了常数输出,解决方法是对特征进行标准化3。为了测试使用 encoder 的中间层输出效果是否更好,取 encoder 的最后 m 个 blocks 的输出经过 batch normalization 后的均值作为训练目标3,发现 m=1 时效果最好(实验基于非 EMA 的算法,迭代 5 次):

m linear evaluation
1 67.97
2 66.84
3 66.71

还发现在 bootstrap2 中略微调大学习率或者不使用学习率衰减效果更好:

lr linear evaluation
余弦学习率衰减 67.97
无学习率衰减 68.24

联想到上面提到的双 decoder 算法,重构像素的分支也可以理解为一种防止模型崩溃的正则化2。但加入 bootstrap2 后反而使得性能下降。推测原因还是之前分析的,使用 pixel loss 强行引入了过于细节的像素信息,占用了模型的表达能力:

loss function linear evaluation
feature_loss 68.24
feature_loss+pixel_loss 67.64

结论

所以提高 bootstrap MAE 性能的调参方法是:略微调大学习率、采用正则化、从随机权值开始 bootstrap,完全不引入像素的信息。不过这套参数并不适用于别的 decoder 架构。例如当设置 decoder 为 4 层时,甚至不能超过 baseline。这说明 bootstrap 对超参数的设置是很敏感的。

一些说明

随着实验的进行发现可能 linear evaluation 不是最合适的,因为 MAE 提取出的特征不是线性可分的1,更好的 linear evaluation 表现不一定能够说明模型有更好的表达能力。文中提出的 partial finetune 的方法会更合适。

另外,在完成最终代码之后,重新跑一开始的数据,发现结果的绝对值与之前相比会有一些波动(相对值没有变化),不清楚是改写代码的哪一步产生了影响。于是重跑之前的数据,但由于时间原因只重跑了重要的数据,因此数据表格不是非常的完整,比如没有 finetune 的部分。


  1. He, Kaiming, Xinlei Chen, Saining Xie, Yanghao Li, Piotr Dollár, and Ross Girshick. 2021. “Masked Autoencoders Are Scalable Vision Learners.” arXiv. http://arxiv.org/abs/2111.06377. ↩︎ ↩︎ ↩︎ ↩︎

  2. Dong, Xiaoyi, Jianmin Bao, Ting Zhang, Dongdong Chen, Weiming Zhang, Lu Yuan, Dong Chen, Fang Wen, and Nenghai Yu. 2022. “Bootstrapped Masked Autoencoders for Vision BERT Pretraining.” arXiv. http://arxiv.org/abs/2207.07116. ↩︎ ↩︎

  3. Baevski, Alexei, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, and Michael Auli. 2022. “Data2vec: A General Framework for Self-Supervised Learning in Speech, Vision and Language.” arXiv. http://arxiv.org/abs/2202.03555. ↩︎ ↩︎ ↩︎ ↩︎

  4. 前 100 epochs 从 0.98 线性增长到 0.99,后 100 epochs 保持 0.99。 ↩︎

你可能感兴趣的:(计算机视觉,深度学习,deep,learning,computer,vision)