连续学习的概念大概是在2016年以后才开始流行的,虽然今天的工业界中几乎都是使用一个或多个模型对应一个任务,但是为了让机器更像人,让机器能同时解决多个任务,同时把过去的知识运用到新的任务上,也是值得研究的课题。
使用Relu和线性层组成的全连接网络实现多个MNIST图像分类任务。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torchvision
from torchvision import datasets, transforms
import numpy as np
import os
import random
from copy import deepcopy
import json
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(28*28, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 256)
self.fc4 = nn.Linear(256, 128)
self.fc5 = nn.Linear(128, 128)
self.fc6 = nn.Linear(128, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = x.view(-1, 28*28)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
x = self.relu(x)
x = self.fc4(x)
x = self.relu(x)
x = self.fc5(x)
x = self.relu(x)
x = self.fc6(x)
return x
EWC的基础思想是把已经训练好的模型中的比较重要的参数用正则化项保护起来,让它们变得不那么容易被更新,从而旧的知识就不会被完全洗掉。把参数的损失函数写出则是下面的公式
L B = L ( θ ) + ∑ i λ 2 F i ( θ i − θ A , i ∗ ) 2 \mathcal{L}_B = \mathcal{L}(\theta) + \sum_{i} \frac{\lambda}{2} F_i (\theta_{i} - \theta_{A,i}^{*})^2 LB=L(θ)+i∑2λFi(θi−θA,i∗)2
我们在基础的损失函数上增加一个正则化项,每个参数受到自己在上一个任务训练完毕后,最后的参数值的约束。lambda是一般的系数,其中的F评估参数重要程度。
如果一个参数很重要,它应该有很大的F。我们有很多种方法定义F,比如我们可以计算参数的对损失函数的二阶偏导,也可以用更简单的一阶方法定义,比如下面的
F = [ ∇ log ( p ( y n ∣ x n , θ A ∗ ) ∇ log ( p ( y n ∣ x n , θ A ∗ ) T ] F = [ \nabla \log(p(y_n | x_n, \theta_{A}^{*}) \nabla \log(p(y_n | x_n, \theta_{A}^{*})^T ] F=[∇log(p(yn∣xn,θA∗)∇log(p(yn∣xn,θA∗)T]
也就是我们计算给定数据集x和终末状态参数,计算分类正确的先验概率的梯度,再对梯度求内积。这将是一个还不错的重要度估计量,实际上这个公式的推导还是有点东西的,是fisher信息矩阵的简单近似,这里不细讲,只知道如何实现即可。
既然知道了怎么计算重要度,用Pytorch实现就只需要把整个训练集丢进模型,softmax计算概率然后取对应位置的正确概率,对所有数据点求平均,反向传播计算梯度,对每个模型参数的梯度计算内积,就结束啦。
至于正则化就更简单了,每次梯度下降更新参数后,再做一个L2惩罚就行了。或者直接把L2的计算和loss func加在一起再梯度下降也是可以的。
我们做两个任务,一个是MNIST手写数字识别,另一个是USPS手写数字识别,在此之前我们需要把USPS的16x16转为28x28.
MNIST_transform = transforms.Compose([
transforms.ToTensor(),
])
USPS_transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
])
mnist = torchvision.datasets.MNIST(
root='C:/Users/Administrator/DL',
train=True,
transform = MNIST_transform
)
usps = datasets.USPS(
root = 'C:/Users/Administrator/DL',
transform=USPS_transform,
train = True,
download=True
)
batch_size = 100
mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
usps_loader = torch.utils.data.DataLoader(dataset=usps,
batch_size=batch_size,
shuffle=True)
我们需要打印出EWC的成效,在使用EWC后,即使学习了usps也不会让mnist的正确率下降太多。首先进行普通的训练,看一看如果不进行fine tune,在学习了一个任务后,模型还能不能在另一个任务上表现的好。
def normal_train(model, optimizer, loader, summary_epochs):
model.train()
model.zero_grad()
loss_func = nn.CrossEntropyLoss()
losses = []
loss = 0.0
for epoch in range(summary_epochs):
for step, (imgs, labels) in enumerate(loader):
imgs, labels = imgs.to(device), labels.to(device)
outputs = model(imgs)
ce_loss = loss_func(outputs, labels)
optimizer.zero_grad()
ce_loss.backward()
optimizer.step()
loss += ce_loss.item()
if (step + 1) % 20 == 0:
loss = loss / 20
print ("\r", "Epoch {}, step {}, loss: {:.3f} ".format(epoch + 1,step+1,loss), end=" ")
losses.append(loss)
loss = 0.0
return losses
def verify(model, loader):
with torch.no_grad():
correct = 0
total = 0
for images, labels in loader:
images = images.reshape(-1, 28*28).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on given dataset: {} %'.format(100 * correct / total))
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
normal_train(model,optimizer,mnist_loader,10)
verify(model,mnist_loader)
normal_train(model,optimizer,usps_loader,10)
verify(model,usps_loader)
verify(model,mnist_loader)
Epoch 10, step 600, loss: 0.018
Accuracy of the network on given dataset: 99.51333333333334 %
Epoch 10, step 60, loss: 0.018
Accuracy of the network on given dataset: 99.21821423673022 %
Accuracy of the network on given dataset: 77.80166666666666 %
实现EWC的train的时候,先计算模型的参数重要度矩阵,用重要度做参数,在原loss上增加L2正则。
def ewc_train(model, optimizer, previous_loader, loader, summary_epochs, lambda_ewc):
# 计算重要度矩阵
params = {n: p for n, p in model.named_parameters() if p.requires_grad}# 模型的所有参数
_means = {} # 初始化要把参数限制在的参数域
for n, p in params.items():
_means[n] = p.clone().detach()
precision_matrices = {} #重要度
for n, p in params.items():
precision_matrices[n] = p.clone().detach().fill_(0) #取zeros_like
model.eval()
for data, labels in previous_loader:
model.zero_grad()
data, labels = data.to(device),labels.to(device)
output = model(data)
############ 核心代码 #############
loss = F.nll_loss(F.log_softmax(output, dim=1), labels)
# 计算labels对应的(正确分类的)对数概率,并把它作为loss func衡量参数重要度
loss.backward() # 反向传播计算导数
for n, p in model.named_parameters():
precision_matrices[n].data += p.grad.data ** 2 / len(previous_loader)
########### 计算对数概率的导数,然后反向传播计算梯度,以梯度的平方作为重要度 ########
model.train()
model.zero_grad()
loss_func = nn.CrossEntropyLoss()
losses = []
loss = 0.0
for epoch in range(summary_epochs):
for step, (imgs, labels) in enumerate(loader):
imgs, labels = imgs.to(device), labels.to(device)
outputs = model(imgs)
ce_loss = loss_func(outputs, labels)
total_loss = ce_loss
# 额外计算EWC的L2 loss
ewc_loss = 0
for n, p in model.named_parameters():
_loss = precision_matrices[n] * (p - _means[n]) ** 2
ewc_loss += _loss.sum()
total_loss += lambda_ewc * ewc_loss
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
loss += total_loss.item()
if (step + 1) % 20 == 0:
loss = loss / 20
print ("\r", "Epoch {}, step {}, loss: {:.3f} ".format(epoch + 1,step+1,loss), end=" ")
losses.append(loss)
loss = 0.0
return losses
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
normal_train(model,optimizer,mnist_loader,10)
verify(model,mnist_loader)
ewc_train(model,optimizer,mnist_loader,usps_loader,10,350)
verify(model,usps_loader)
verify(model,mnist_loader)
Epoch 10, step 600, loss: 0.039
Accuracy of the network on given dataset: 99.67166666666667 %
Epoch 10, step 60, loss: 0.017
Accuracy of the network on given dataset: 99.71197366616376 %
Accuracy of the network on given dataset: 91.02666666666667 %
可以看见,使用EWC前后,USPS的训练正确率并没有下降,但是却使用正则化保住了前面的MNIST任务的正确率。
Life Long Learning是一种非常新奇的技术,和工业界中在使用被用来解决问题的强化学习和迁移学习技术不太一样;虽然这种连续学习技术也是为了把多个任务融会贯通,但是LLL更偏重于如何让模型学会多种知识且不忘记之前的知识,用一个模型解决多种问题。
虽然这种技术现在看来是一种很没用的技术,但是为了让机器更像人,实现强人工智能以及Learn to learn的梦想,这种长期学习的技术又是必不可少的。