基于知识蒸馏Knowledge Distillation模型压缩pytorch实现

在弄懂原理基础上,从本篇博客开始,逐步介绍基于知识蒸馏的增量学习、模型压缩的代码实现。毕竟“纸上得来终觉浅,绝知此事要躬行。”。

先从最经典的Hilton论文开始,先实现基于知识蒸馏的模型压缩。相关原理可以参考博客:https://blog.csdn.net/zhenyu_an/article/details/101646943,

既然基本原理是用一个已训练的teacher网络,去教会一个student网络,那首先需要定义这两个网络如下。这里我们采用pytorch语言,以最简单的mnist数据集为例来看看效果。

先定义student网络,由一个卷积层、池化层、全连接层构成,很简单。

class anNet(nn.Module):
    def __init__(self):
        super(anNet,self).__init__()
        self.conv1 = nn.Conv2d(1,6,3)
        self.pool1 = nn.MaxPool2d(2,1)
        self.fc3 = nn.Linear(3750,10)
    def forward(self,x):
        x = self.conv1(x)
        x = self.pool1(F.relu(x))
        x = x.view(x.size()[0],-1)
        x = self.fc3(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()

再定义一个teacher网络,由两个卷积、两个池化、一个全连接层组成。

class anNet_deep(nn.Module):
    def __init__(self):
        super(anNet_deep,self).__init__()
        self.conv1 = nn.Sequential(
                nn.Conv2d(1,64,3,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU())
        self.conv2 = nn.Sequential(
                nn.Conv2d(64,64,3,1,padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU())
        self.conv3 = nn.Sequential(
                nn.Conv2d(64,128,3,1,padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU())
        self.conv4 = nn.Sequential(
                nn.Conv2d(128,128,3,1,padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU())
        self.pooling1 = nn.Sequential(nn.MaxPool2d(2,stride=2))
        self.fc = nn.Sequential(nn.Linear(6272,10))
    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.pooling1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pooling1(x)
        x = x.view(x.size()[0],-1)
        x = self.fc(x)
        return x
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.normal_(m.weight.data, 0, 0.01)
                m.bias.data.zero_()

为了提高teacher网络的性能,在每个卷积层后面加上了BN层。通过print(sum(x.numel() for x in model.parameters()))可以计算出teacher网络和student网络的参数个数分别为:618186和37570,二者相差16倍。

我们首先在mnist数据集上分别对两种模型训练,采用相同的优化方法optimizer = optim.Adam(model.parameters(),lr = 0.001)、相同的损失函数criterion = nn.CrossEntropyLoss()和相同的epoch,teacher网络得到最佳测试正确率大约在0.989至0.991之间,student网络得到的最佳测试正确率大约在0.957至0.959之间。总体而言,在16倍的参数差距面前,student网络干不过teacher网络。把teacher网络训练好的模型保存好。

下面开始知识蒸馏中的关键代码:

知识蒸馏的关键是loss的设计,它包括普通的交叉熵loss1和建立在软目标基础上的loss2。分别如下:

# 损失函数
criterion = nn.CrossEntropyLoss()
criterion2 = nn.KLDivLoss()
# 经典损失
outputs = model(inputs.float())
loss1 = criterion(outputs, labels)
# 蒸馏损失        
teacher_outputs = teach_model(inputs.float())
T = 2
alpha = 0.5
outputs_S = F.log_softmax(outputs/T,dim=1)
outputs_T = F.softmax(teacher_outputs/T,dim=1)
loss2 = criterion2(outputs_S,outputs_T)*T*T

#综合损失结果
loss = loss1*(1-alpha) + loss2*alpha

仔细看这段代码,loss1的设计没有问题,它衡量student网络输出与标准值labels的差距,用的是交叉熵损失。

loss2是关键的蒸馏损失,它的衡量的是student网络输出与已训练好的teacher网络输出,经过软化的结果之间差距。其中outputs_S 代表student网络输出软化后结果,outputs_T 代表teacher网络输出软化后结果,二者采用的是KL散度损失函数。T和alpha是两个超参数,取法对结果影响很大,T的取法一般可以有2,10,20几种,alpha一般取0.5,0.9,0.95几种。需要留意的是这里采用的两种软化方法,student网络输出后加一个log_softmax(outputs/T),teacher网络输出后加一个softmax(参考了https://github.com/PolarisShi/distillation的写法)。这里的问题在于,pytorch源码实现的KL散度是一个阉割版本,并没有对预测结果做log处理,作者试图在这里给补上。事实上,这种写法也没有完全实现标准KL散度的公式,因为漏了log之前的值,最后写成了四不像,反倒是都用softmax的话至少与pytorch的思路一致。待重开一篇博客详细介绍pytorch中的KL散度与交叉熵。

我们把训练好的teacher网络参数导入,开始蒸馏训练,student网络最后精度可以提升到97.02%,没有预想中效果明显,可能是因为超参数的取值不合适。

完整代码和训练好的模型见个人github:https://github.com/azy1988/ML-CV/tree/master/model_distillation

 

你可能感兴趣的:(增量学习,图像分类,模型压缩,知识蒸馏,深度学习)