Knowledge Distillation: A Worked Example

1 Experiment Overview

  1. First, ResNet-18 and ResNet-50 are each fine-tuned on CIFAR-10 and their accuracies are measured as baselines. Both start from the torchvision ImageNet-pretrained weights; the final fc layer is replaced to output 10 classes and the network is fine-tuned on CIFAR-10 (a minimal sketch of this baseline step follows this list).
  2. With all other settings unchanged, ResNet-50 is then used as the teacher to train ResNet-18 by knowledge distillation, and the student's accuracy is measured again.
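
A minimal sketch of the baseline fine-tuning step (the run that produces resnet50_cifar10.pth, later used as the teacher). It reuses create_data, load_checkpoint, test, and img_dir from the full listing in Section 2. The ImageNet checkpoint path, the function name finetune_teacher, and the optimizer/epoch settings (Adam, lr=0.0001, 100 epochs, mirroring the distillation run) are assumptions; the original does not spell out the baseline hyperparameters.

from torchvision.models.resnet import resnet50
import torch
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

def finetune_teacher():
    # Replace the fc layer by constructing the model with 10 classes,
    # then load ImageNet weights for everything except fc.
    net = resnet50(num_classes=10).cuda()
    # assumption: torchvision ImageNet checkpoint downloaded to ./weights/
    load_checkpoint(net, "./weights/resnet50-19c8e357.pth", exclude_fc=True)

    train_loader, test_loader = create_data(img_dir)
    opt = Adam(net.parameters(), lr=0.0001)   # assumption: same optimizer as the KD run
    criterion = CrossEntropyLoss()

    for epoch in range(100):                  # assumption: same epoch budget as the KD run
        net.train()
        for image, target in train_loader:
            image, target = image.cuda(), target.cuda()
            opt.zero_grad()
            loss = criterion(net(image), target)
            loss.backward()
            opt.step()
        test(net, test_loader)

    torch.save(net.state_dict(), "./weights/resnet50_cifar10.pth")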

2 Code Implementation

The only part that differs from standard training is the loss. Besides the usual cross-entropy loss computed from the ground-truth labels, an extra term computed against the teacher model's outputs is added; see KD_loss in the code. The distillation loss follows Distilling the Knowledge in a Neural Network (Hinton et al.).
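
Written out, the total training loss implemented below is (with $z_s$ and $z_t$ the student and teacher logits, $y$ the ground-truth label, and $T$ the softmax temperature):

$$\mathcal{L} = \mathrm{CE}(z_s,\, y) \;+\; T^2 \cdot \mathrm{KL}\big(\mathrm{softmax}(z_t/T)\,\|\,\mathrm{softmax}(z_s/T)\big), \qquad T = 4$$

The $T^2$ factor compensates for the $1/T^2$ scaling that the temperature introduces into the gradients of the soft-target term. Unlike the weighted combination in the original paper, the two terms here are simply summed with equal weight.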

from torchvision.models.resnet import resnet18, resnet50
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn

resnet18_pretrain_weight = "./weights/resnet18-5c106cde.pth"
resnet50_pretrain_weight = "./weights/resnet50_cifar10.pth"
img_dir = "/data/cifar10/"


def create_data(img_dir):
    dataset = dst.CIFAR10
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=512, shuffle=True, num_workers=4, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=512, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader


def load_checkpoint(net, pth_file, exclude_fc=False):
    if exclude_fc:
        model_dict = net.state_dict()
        pretrain_dict = torch.load(pth_file)
        new_dict = {k: v for k, v in pretrain_dict.items() if 'fc' not in k}
        model_dict.update(new_dict)
        net.load_state_dict(model_dict, strict=True)
    else:
        pretrain_dict = torch.load(pth_file)
        net.load_state_dict(pretrain_dict, strict=True)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the transposed slice may be non-contiguous
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class KD_loss(nn.Module):
    """Soft-target distillation loss of Hinton et al.: T^2 * KL(teacher || student)."""
    def __init__(self, T):
        super(KD_loss, self).__init__()
        self.T = T  # softmax temperature

    def forward(self, out_s, out_t):
        loss = F.kl_div(F.log_softmax(out_s / self.T, dim=1),
                        F.softmax(out_t / self.T, dim=1),
                        reduction='batchmean') * self.T * self.T

        return loss


def test(net, test_loader):
    prec1_sum = 0
    prec5_sum = 0
    net.eval()
    for i, (img, target) in enumerate(test_loader, start=1):
        # print(f"batch: {i}")
        img = img.cuda()
        target = target.cuda()

        with torch.no_grad():
            out = net(img)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        prec1_sum += prec1
        prec5_sum += prec5
        # print(f"batch: {i}, acc1:{prec1}, acc5:{prec5}")
    print(f"Acc1:{prec1_sum / (i + 1)}, Acc5: {prec5_sum / (i + 1)}")


def train(net_s, net_t, train_loader, test_loader):
    # opt = Adam(filter(lambda p: p.requires_grad,net.parameters()), lr=0.0001)
    opt = Adam(net_s.parameters(), lr=0.0001)
    net_s.train()
    net_t.eval()
    for epoch in range(100):
        for step, batch in enumerate(train_loader):
            opt.zero_grad()
            image, target = batch
            image = image.cuda()
            target = target.cuda()
            out_s = net_s(image)
            with torch.no_grad():  # the teacher is frozen; no need to build its graph
                out_t = net_t(image)
            loss_init = CrossEntropyLoss()(out_s, target)
            loss_kd = KD_loss(T=4)(out_s, out_t)
            loss = loss_init + loss_kd
            # prec1, prec5 = accuracy(predict, target, topk=(1, 5))
            # print(f"epoch:{epoch}, step:{step}, loss:{loss.item()}, acc1: {prec1},acc5:{prec5}")
            loss.backward()
            opt.step()
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)

    torch.save(net_s.state_dict(), './resnet18_cifar10_kd.pth')


def main():
    net_t = resnet50(num_classes=10)
    net_s = resnet18(num_classes=10)
    net_t = net_t.cuda()
    net_s = net_s.cuda()
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)
    # for name, value in net.named_parameters():
    #     if 'fc' not in name:
    #         value.requires_grad = False

    train_loader, test_loader = create_data(img_dir)
    train(net_s, net_t, train_loader, test_loader)
    # test(net, test_loader)


if __name__ == "__main__":
    main()
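
After training, the saved student checkpoint can be re-evaluated with the same helpers; a minimal sketch reusing resnet18, load_checkpoint, create_data, test, and img_dir from the listing above:

net = resnet18(num_classes=10).cuda()
load_checkpoint(net, './resnet18_cifar10_kd.pth', exclude_fc=False)
_, test_loader = create_data(img_dir)
test(net, test_loader)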

3 Experimental Results

teacher model    student model    CIFAR-10 top-1/top-5 (%)
-                resnet18         80.34/94.24
-                resnet50         83.20/94.51
resnet50         resnet18         82.25/94.44

Accuracy convergence trend:
[Figure 1: accuracy convergence curves during training]
The experiments show that distillation with the ResNet-50 teacher clearly improves ResNet-18: top-1 accuracy rises from 80.34% to 82.25%.

Note: this article only aims to verify the effect of knowledge distillation, so no training tricks or careful hyperparameter tuning were used; the reported accuracies are not state-of-the-art.
