与标准训练的不同之处在于损失(loss)部分：除了由真实标签计算的传统交叉熵损失之外，还额外添加了学生模型输出与教师模型输出之间的蒸馏损失，见代码中的 KD_loss。本文采用的是论文 Distilling the Knowledge in a Neural Network (Hinton et al.) 中提出的蒸馏损失。
from torchvision.models.resnet import resnet18, resnet50
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn
# Checkpoint locations consumed by main():
#   - student (ResNet-18) weights; main() loads them with exclude_fc=True, so the
#     checkpoint's classifier head is ignored (presumably torchvision's
#     ImageNet-pretrained file -- confirm),
#   - teacher (ResNet-50) weights, loaded strictly into a 10-class model.
resnet18_pretrain_weight, resnet50_pretrain_weight = (
    "./weights/resnet18-5c106cde.pth",
    "./weights/resnet50_cifar10.pth",
)

# Root directory handed to torchvision's CIFAR10 dataset (downloaded if absent).
img_dir = "/data/cifar10/"
def create_data(img_dir, batch_size=512, num_workers=4):
    """Build the CIFAR-10 train and test DataLoaders.

    Args:
        img_dir: root directory for ``torchvision.datasets.CIFAR10``; the
            dataset is downloaded there if it is missing (``download=True``).
        batch_size: samples per batch for both loaders (default 512).
        num_workers: loader worker processes per loader (default 4).

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and applies
        reflect-padding + random 32x32 crop + horizontal flip before
        normalization; the test loader only center-crops and normalizes.
    """
    dataset = dst.CIFAR10
    # Per-channel mean/std used to normalize CIFAR-10 images.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        # A no-op for native 32x32 CIFAR images; kept so larger inputs are cropped.
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        normalize,
    ])

    # Pinned memory only helps (and only avoids a warning) when a GPU is present.
    pin_memory = torch.cuda.is_available()

    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, test_loader
def load_checkpoint(net, pth_file, exclude_fc=False):
    """Load a saved state_dict from ``pth_file`` into ``net`` with strict key matching.

    Args:
        net: module to load into; its parameters/buffers are updated in place.
        pth_file: path to a file written by ``torch.save(state_dict)``.
        exclude_fc: if True, ignore every checkpoint entry whose key has an
            ``fc`` path component (e.g. ``fc.weight``) and keep ``net``'s own
            values for those entries. This is useful when the checkpoint's
            classifier head has a different number of classes.
    """
    # load_state_dict copies values into net's existing tensors, so loading onto
    # the CPU works no matter which device the checkpoint was saved from.
    pretrain_dict = torch.load(pth_file, map_location="cpu")
    if exclude_fc:
        # Start from net's own state so the skipped fc entries are still present
        # and strict=True succeeds. Match "fc" as a whole key component rather
        # than a substring, so unrelated keys such as "sfc.weight" still load.
        model_dict = net.state_dict()
        model_dict.update({k: v for k, v in pretrain_dict.items()
                           if 'fc' not in k.split('.')})
        pretrain_dict = model_dict
    net.load_state_dict(pretrain_dict, strict=True)
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: scores with the batch on dim 0 and classes on dim 1.
        target: ground-truth class indices, one per sample (length = batch).
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target is among their k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    # After the transpose, row j holds every sample's j-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # reshape, not view: `correct` inherits the non-contiguous layout of
        # pred.t(), so .view(-1) raises a RuntimeError for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
class KD_loss(nn.Module):
    """Soft-target distillation loss (Hinton et al., "Distilling the Knowledge
    in a Neural Network").

    Computes KL(softmax(out_t / T) || softmax(out_s / T)), averaged over the
    batch (``reduction='batchmean'``) and multiplied by ``T ** 2``.

    Args:
        T: softmax temperature applied to both student and teacher logits.
    """

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, out_s, out_t):
        """Return the scaled KD loss between student logits ``out_s`` and
        teacher logits ``out_t`` (classes on dim 1)."""
        temperature = self.T
        student_log_probs = F.log_softmax(out_s / temperature, dim=1)
        teacher_probs = F.softmax(out_t / temperature, dim=1)
        divergence = F.kl_div(student_log_probs, teacher_probs,
                              reduction='batchmean')
        # The T**2 factor follows the paper: it offsets the 1/T**2 shrinkage of
        # soft-target gradients so their scale stays comparable as T changes.
        return divergence * temperature * temperature
def test(net, test_loader):
    """Evaluate ``net`` on ``test_loader`` and print dataset-wide top-1/top-5 accuracy.

    Inputs are moved to the device holding ``net``'s parameters. The module's
    train/eval mode is restored before returning, so calling this from inside a
    training loop does not leave BatchNorm/Dropout stuck in eval mode.
    """
    was_training = net.training
    device = next(net.parameters()).device
    acc1_weighted = 0.0  # sum over batches of (batch top-1 %) * batch size
    acc5_weighted = 0.0  # same for top-5
    n_samples = 0
    net.eval()
    try:
        with torch.no_grad():
            for img, target in test_loader:
                img = img.to(device)
                target = target.to(device)
                out = net(img)
                prec1, prec5 = accuracy(out, target, topk=(1, 5))
                batch_size = target.size(0)
                # Weight each batch's percentage by its size so a smaller final
                # batch does not skew the overall average.
                acc1_weighted += prec1.item() * batch_size
                acc5_weighted += prec5.item() * batch_size
                n_samples += batch_size
    finally:
        net.train(was_training)
    if n_samples == 0:
        print("Acc1: n/a, Acc5: n/a (empty test_loader)")
        return
    print(f"Acc1:{acc1_weighted / n_samples}, Acc5: {acc5_weighted / n_samples}")
def train(net_s, net_t, train_loader, test_loader, epochs=100, lr=0.0001, T=4,
          save_path='./resnet18_cifar10_kd.pth'):
    """Distill the fixed teacher ``net_t`` into the student ``net_s``.

    Each step minimizes cross-entropy against the labels plus ``KD_loss``
    between student and teacher logits at temperature ``T``; only the
    student's parameters are optimized (Adam, learning rate ``lr``). After
    every epoch the student is evaluated with ``test`` and its state_dict is
    written to ``save_path`` (overwritten each epoch).

    Batches are moved to the device of ``net_s``'s parameters; ``net_t`` is
    assumed to live on that same device.
    """
    opt = Adam(net_s.parameters(), lr=lr)
    ce_loss = CrossEntropyLoss()
    kd_loss = KD_loss(T=T)
    device = next(net_s.parameters()).device
    net_t.eval()
    for epoch in range(epochs):
        # Re-enter train mode every epoch: test() below switches net_s to eval.
        net_s.train()
        for image, target in train_loader:
            image = image.to(device)
            target = target.to(device)
            out_s = net_s(image)
            # The teacher is not trained: skip building its autograd graph, which
            # also stops .grad from accumulating on its parameters.
            with torch.no_grad():
                out_t = net_t(image)
            loss_init = ce_loss(out_s, target)
            loss_kd = kd_loss(out_s, out_t)
            loss = loss_init + loss_kd
            opt.zero_grad()
            loss.backward()
            opt.step()
        # Losses reported are those of the last batch of the epoch.
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)
        torch.save(net_s.state_dict(), save_path)
def main():
    """Entry point: distill a 10-class ResNet-50 teacher into a ResNet-18 student on CIFAR-10."""
    # Build both 10-class networks (teacher first), then move them to the GPU.
    net_t, net_s = resnet50(num_classes=10), resnet18(num_classes=10)
    net_t, net_s = net_t.cuda(), net_s.cuda()

    # Teacher: restore every weight, including its classifier head.
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    # Student: reuse the checkpoint's backbone but keep the freshly initialised
    # fc layer (the checkpoint's head presumably has a different class count).
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)

    loaders = create_data(img_dir)  # (train_loader, test_loader)
    train(net_s, net_t, *loaders)


if __name__ == "__main__":
    main()
| teacher model | student model | CIFAR-10 Acc@1/Acc@5 (%) |
|---|---|---|
| - | resnet18 | 80.34/94.24 |
| - | resnet50 | 83.20/94.51 |
| resnet50 | resnet18 | 82.25/94.44 |
精度收敛趋势:
实验表明，通过知识蒸馏，resnet18 的 Top-1 精度从 80.34% 提升到 82.25%，提升明显。
注：本文旨在验证知识蒸馏的效果，因此模型没有使用各种训练技巧(trick)，也没有进行精细调优，精度并非 SOTA。