Knowledge distillation: first proposed by Hinton et al. in the 2015 paper "Distilling the Knowledge in a Neural Network" and applied to classification tasks. The large model is called the teacher, the small model is called the student, the supervisory information coming from the teacher's outputs is called the knowledge, and the process by which the student learns this supervision transferred from the teacher is called distillation. Below is a small introductory example of knowledge distillation (the code runs as-is):
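For reference, the objective implemented in `train_distillation` below is a standard weighted combination of the hard-label cross-entropy and a KL-divergence term between temperature-softened teacher and student distributions (the function's `alpha` and `temp` arguments play the roles of $\alpha$ and $T$):

$$\mathcal{L} = \alpha \,\mathrm{CE}(z_s, y) + (1-\alpha)\,\mathrm{KL}\big(\mathrm{softmax}(z_t / T)\,\|\,\mathrm{softmax}(z_s / T)\big),$$

where $z_s$ and $z_t$ are the student and teacher logits. Hinton et al. additionally scale the soft-target term by $T^2$ so that its gradient magnitude stays comparable to the hard-label term; the example below omits that factor for simplicity.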
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchinfo import summary  # optional: only needed for the commented-out summary() call below
# Teacher: the larger CNN (two conv blocks, 64 -> 256 channels)
class TeacherModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(TeacherModel, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(
            in_channels=64,
            out_channels=256,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.fc1 = nn.Linear(256 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)  # flatten (N, 256, 7, 7) -> (N, 256*7*7)
        x = self.fc1(x)
        return x

# Student: a much smaller CNN (two conv blocks, 8 -> 16 channels)
class StudentModel(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(StudentModel, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=8,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(
            in_channels=8,
            out_channels=16,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
        )
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)  # flatten (N, 16, 7, 7) -> (N, 16*7*7)
        x = self.fc1(x)
        return x

def check_accuracy(loader, model, device):
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            _, predictions = preds.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
    model.train()
    return (num_correct / num_samples).item()

def train_model(model, epochs):
    """Plain supervised training on hard labels (uses the global train_loader/test_loader)."""
    # print(summary(model))  # teacher params: 273,802; student params: 9,098
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(epochs):
        model.train()
        losses = []
        pbar = tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}")
        for data, labels in pbar:
            data, labels = data.to(device), labels.to(device)
            # forward
            preds = model(data)
            loss = criterion(preds, labels)
            losses.append(loss.item())
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_loss = sum(losses) / len(losses)
        acc = check_accuracy(test_loader, model, device)
        print(f"Loss:{avg_loss:.2f}\tAccuracy:{acc:.2f}")
    return model

def train_distillation(teacher, student, epochs, temp=7, alpha=0.3):
    """Train `student` with a mix of hard-label loss and KL divergence to the teacher's softened outputs."""
    student_loss_fn = nn.CrossEntropyLoss()
    divergence_loss_fn = nn.KLDivLoss(reduction="batchmean")
    optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)
    teacher.eval()
    student.train()
    for epoch in range(epochs):
        losses = []
        pbar = tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}")
        for data, labels in pbar:
            data, labels = data.to(device), labels.to(device)
            # forward: the teacher only provides targets, so no gradients are needed
            with torch.no_grad():
                teacher_preds = teacher(data)
            student_preds = student(data)
            # hard-label loss
            student_loss = student_loss_fn(student_preds, labels)
            # soft-label loss: KLDivLoss expects log-probabilities as input and
            # probabilities as target
            distillation_loss = divergence_loss_fn(
                F.log_softmax(student_preds / temp, dim=1),
                F.softmax(teacher_preds / temp, dim=1),
            )
            loss = alpha * student_loss + (1 - alpha) * distillation_loss
            losses.append(loss.item())
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        avg_loss = sum(losses) / len(losses)
        acc = check_accuracy(test_loader, student, device)
        print(f"Loss:{avg_loss:.2f}\tAccuracy:{acc:.2f}")

if __name__ == '__main__':
    torch.manual_seed(0)  # fix the random seed for reproducibility
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms
    # Load MNIST train dataset
    train_dataset = torchvision.datasets.MNIST(
        root="data/",
        train=True,
        transform=transforms.ToTensor(),
        download=True
    )
    # Load MNIST test dataset
    test_dataset = torchvision.datasets.MNIST(
        root="data/",
        train=False,
        transform=transforms.ToTensor(),
        download=True
    )
    # Create train and test dataloaders
    train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)
    epochs = 6
    print("---------- teacher ----------")
    teacher_model = train_model(TeacherModel().to(device), epochs)
    print("---------- student ----------")
    train_model(StudentModel().to(device), epochs)
    print("---------- student and teacher ----------")
    train_distillation(teacher_model, StudentModel().to(device), epochs)
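Running the script first trains the teacher (273,802 parameters, per the commented-out torchinfo summary), then a student from scratch on hard labels only (9,098 parameters), and finally a fresh student distilled from the trained teacher, so the three accuracy printouts on the MNIST test set let you compare the plain student against the distilled one under the same budget of 6 epochs.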