模型压缩(4)——知识蒸馏

对于大的数据集,小模型往往很难获得较高的精度;知识蒸馏则是使用大模型指导小模型,使小模型学到大模型包含的知识,从而得到更高的精度。原理网上很多,主要是给softmax加温度实现,这里不赘述,直接讲简单实现。

1. 训练教师模型

教师模型用于仅用于指导小模型,不参与部署,因此在条件允许的情况下,可以选择很大的模型,举个例子:

from torch import nn
from torchvision import models
from torchsummary import summary
# 使用预训练的resnet152, 加载预训练权重
model = models.resnet152(pretrained=True)
# 修改全连接层,改为自己的预测类别数
model.fc = nn.Linear(model.fc.in_features, 10)
summary(model, (3, 224, 224), device="cpu")

然后使用上述模型在自己的数据集上进行微调即可。

2. 训练学生模型

此时的学生模型可以选择很小的模型,使用知识蒸馏比常规训练主要多了以下两个步骤:

2.1 定义损失函数

criteon = nn.CrossEntropyLoss()
t_loss = nn.KLDivLoss(reduction="batchmean")

2.2 加载教师模型

if distilled:
    t_net = models.resnet152().to(device)
    try:
        t_net.load_state_dict(torch.load("./model/resnet152.pth"))
        print("successful")
    except:
        print("failed")

2.3  计算loss

# T是温度,用于削弱softmax,可以自己调
if distilled:
    loss = 0.7 * criteon(logits, target.to(device)) + 0.3 * t_loss(F.log_softmax(logits/T, dim=1),
                                                                      F.softmax(t_logits/T, dim=1))
else:
    loss = criteon(logits, target.to(device))

虽然看原理觉得很有趣,但用mnist、cifar10、cifar100测试并没有感觉有什么用啊。

3. 代码链接

Model-Compression/Distillation at main · liuweixue001/Model-Compression (github.com)

你可能感兴趣的:(Pytorch,模型压缩,pytorch,计算机视觉,深度学习)