【论文解读】终生学习LLL-正则化方法:Memory Aware Synapses

一、简介

AMS可以在无监督和在线学习中计算网络参数的重要性。给与新数据可以计算出网络参数的特征重要性,基于模型数据的L2范数的平方,其参数的梯度反应新数据预测的敏感性,将其作为权重,让其保守变化,提高模型的泛化能力和减少模型的复杂度。
首次将基于未标记数据的参数重要性调整网络需要(不要)忘记的内容的能力,这可能会因测试条件而异。

二、重要贡献

  • 提出 AMS
  • 我们展示了MAS的局部变体是如何与Hebbian学习计划联系在一起的
  • 方法达到了 SOTA, 方法同样适用于对象识别和预测(输出为embedding而不是softmax)

3. MAS 算法

3.1 参数重要性计算

MAS中损失函数如下, 模型在学习任务B之前学习任务A。

L B = L ( θ ) + ∑ i λ 2 Ω i ( θ i − θ A , i ∗ ) 2 \mathcal{L}_B = \mathcal{L}(\theta) + \sum_{i} \frac{\lambda}{2} \Omega_i (\theta_{i} - \theta_{A,i}^{*})^2 LB=L(θ)+i2λΩi(θiθA,i)2

相对EWC来说, 在损失函数中 F i F_i Fi Ω i \Omega_i Ωi 替代, Ω i \Omega_i Ωi 计算方法如下

Ω i = ∣ ∣ ∂ ℓ 2 2 ( M ( x k ; θ ) ) ∂ θ i ∣ ∣ \Omega_i = || \frac{\partial \ell_2^2(M(x_k; \theta))}{\partial \theta_i} || Ωi=∣∣θi22(M(xk;θ))∣∣

x k x_k xk 是之前任务中的样本数据。所以 Ω \Omega Ω是所学习的网络输出的平方L2范数的梯度。 目的:为了在其梯度中寻找对新任务预测敏感的参数,让其保守变化。有效防止与先前任务相关的重要知识被覆盖.
论文中提出的方法是通过从模型的每一层获取平方L2范数输出的局部版本。 下面实现全局版本, 仅用通过模型的最后一层获取输出。

3.2 python实现

具体的应用案例可以看笔者的github: AMS_Train.ipynb

class mas(object):
    def __init__(self, model, dataloader, device, prev_guards=[None]):
        self.model = model 
        self.dataloader = dataloader
        # 提取模型全部参数
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 
        # 参数初始化
        self.p_old = {} 
        self.device = device
        # 保存之前的 guards
        self.previous_guards_list = prev_guards
        # 生成 Omega(Ω) 矩阵
        self._precision_matrices = self._calculate_importance() 
        for n, p in self.params.items():
            # 保留原始数据 - 保存为不可导
            self.p_old[n] = p.clone().detach()

    def _calculate_importance(self):
        out = {}
        # 初始化 Omega(Ω) 矩阵(全部填充0)并加上之前的 guards
        for n, p in self.params.items():
            out[n] = p.clone().detach().fill_(0)
            for prev_guard in self.previous_guards_list:
                if prev_guard:
                    out[n] += prev_guard[n]

        self.model.eval()
        if self.dataloader is not None:
            number_data = len(self.dataloader)
            for x, y in self.dataloader:
                self.model.zero_grad()
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                # 生成 Omega(Ω) 矩阵. 
                # 网络输出 L2范数平方的梯度
                loss = torch.mean(torch.sum(pred ** 2, axis=1))
                loss.backward()
                for n, p in self.model.named_parameters():
                    out[n].data += torch.sqrt(p.grad.data ** 2) / number_data
        out = {n: p for n, p in out.items()}
        return out

    def penalty(self, model: nn.Module):
        loss = 0
        for n, p in model.named_parameters():
            # 最终的正则项 =   Omega(Ω)权重 * 权重变化平方((p - self.p_old[n]) ** 2) 
            _loss = self._precision_matrices[n] * (p - self.p_old[n]) ** 2
            loss += _loss.sum()
        return loss
  
    def update(self, model):
        return 

参考:
Memory Aware Synapses

你可能感兴趣的:(笔记,深度学习,pytorch,人工智能)