论文:https://arxiv.org/pdf/2009.08453v1.pdf
代码:https://github.com/szq0214/MEAL-V2
知识蒸馏是将一个已经训练好的网络迁移到另外一个新网络,常采用teacher-student学习策略,已经被广泛应用在模型压缩和迁移学习中。这里要介绍的MEAL V2是通过知识蒸馏提升ResNet50在ImageNet上的分类准确度,MEAL V2不需要修改网络结构,也不需要其他特殊的训练策略和数据增强就可以使原始ResNet50的Top-1准确度提升至80%+,这是一个非常nice的work。
MEAL V2主要的思路是将多个模型的集成效果通过知识蒸馏迁移到一个单一网络中,整个设计非常简单,只包括三个重要的部分:teacher模型集成,KL散度loss以及一个判别器。相比其它方法,不需要特殊的trick:
MEAL V2是MEAL方法的升级版,相比之下V2版本设计上更简单,效果也更好:
采用多个teacher模型进行集成可以产生更准确的预测以更好地指导student模型训练。原始的MEAL从多个teacher中随机选择一个teacher进行蒸馏,这里是将多个teacher模型的预测概率(softmax后输出)求平均值来进行蒸馏,这实际上是一种模型集成。记teacher模型为,共有K个teacher,那么对输入,模型集成后的概率输出为;
KL散度可以用来衡量两个概率分布的差异,在训练过程中通过最小化student的概率输出和teacher模型集成后概率的KL散度来完成知识蒸馏。这里的损失函数如下:
由于上述公式展开后的第二项是teacher模型集成后概率的熵,对于训练student是一个常量,所以可以忽略,最后就剩下了交叉熵:
这里交叉熵的label是teacher模型集成后的平均概率,而不是传统训练中的one-hot/hard标签,这对于知识蒸馏是至关重要的。原始的知识蒸馏方法是只有一个teacher,但是采用的是smooth后的概率(带有温度的softmax概率)来进行训练,这里采用多模型集成更进一步。相比hard label,soft label其实信息更强,比如下图label为tobacco shop的输入图片,不同模型的输出概率其实比hard label包含了更多信息,如果训练数据有噪音,采用soft label意义就更大了。
MEAL V2采用对抗学习来防止student在训练数据上过拟合,即不让student过分学习teacher的输出,这其实是一种正则化手段。具体做法是加入一个判别器来区分student的输出和teacher的输出,这是一个二分器。二分器采用一个3层FC的子网络,其输入是softmax前的logits。这里的student网络其实充当了生成器的角色,与传统的GAN训练方式不同,这里直接把判别器loss和前面所述的CE loss直接加起来一起训练,具体做法是每个batch中,teacher的输出其GT是[0, 1],而student的输出其GT是[1,0]:
class discriminatorLoss(nn.Module):
def __init__(self, models, loss=nn.BCEWithLogitsLoss()):
super(discriminatorLoss, self).__init__()
self.models = models # 3层FC网络
self.loss = loss
def forward(self, outputs, targets):
"""
outputs和targets分别是student和teacher的logits
"""
inputs = [torch.cat((i,j),0) for i, j in zip(outputs, targets)]
inputs = torch.cat(inputs, 1)
batch_size = inputs.size(0)
target = torch.FloatTensor([[1, 0] for _ in range(batch_size//2)] + [[0, 1] for _ in range(batch_size//2)])
target = target.to(inputs[0].device)
output = self.models(inputs)
res = self.loss(output, target)
return res
这里的判别器其实只是充当一种正则化策略,对训练效果有少量提升,这是因为知识蒸馏中,一般teacher比student强大,就算强制学习,student的teacher也会有一定的差距。
论文中的teacher设置为2个,如果输入size为224,那么teacher为senet154
和resnet152_v1s
,如果输入size为380,那么teacher为efficientnet_b4_ns
和efficientnet_b4
,论文中对ResNet50做了实验,最终在ImageNet上Top-1准确度可以达到80%+:
有一点需要注意,在蒸馏时,ResNet50不是随机初始化的,而是从预训练好的ImageNet模型进行初始化,就是说student也需要一个好的初始化,如果是随机初始化可能需要更长的训练时长。
我个人觉得知识蒸馏的应用会越来越多,不管是在CV领域还是NLP领域。在最新的无监督方法研究如谷歌的SimCLRv2和Noisy Student均有知识蒸馏的身影。
MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks∗
MEAL: Multi-Model Ensemble via Adversarial Learning
Distilling the Knowledge in a Neural Network
szq0214/MEAL-V2