本设计思路来源于论文《Dynamic Data-Free Knowledge Distillation by Easy-to-Hard Learning Strategy》。
总体架构图如下。在常规的知识蒸馏中,一般不会考虑知识的难度先后,按照我们人类的思维,肯定是先学习容易的再学习难一点的知识(总不能小学就学高数吧哈哈)。一个模型的理想状态也应该如此。
在本论文的设计图中,可以看到Generator负责生成伪数据提供给教师模型和学生模型,但是这个Generator是受到一个随时间变化的Adversarial Scheduler调节的,这个Adversarial Scheduler的作用就是让Generator随时间生成从易到难的知识(也就是图像)提供给学生模型和教师模型。
有着一个contribution还不够,作者又在最后加上了一个Reweighting
Vector,这个向量能够修改不同样本的影响力,开始时强调更简单的样本,并逐渐包含更难的样本。具体来说,对于那些模型预测相对准确(即易于模型学习)的样本,会分配较低的权重;而对于模型预测不准确(即难以模型学习)的样本,会分配较高的权重。这样,模型初期会更多地关注易于学习的样本,随着训练的进行,逐渐增加对难样本的关注,从而实现由易到难的学习策略。
通过这两个模块,可以非常简单高效地让学生模型从易到难地学习。
那这样做的优势除了人之常理还有没有更有说服力的解释呢?
通过设置Adversarial Scheduler,模型能够在早期阶段避免过度的对抗性扰动,从而保证基础知识的有效学习,而在后期阶段逐渐增加对抗性挑战,促进模型在更复杂的情况下的性能提升。这种策略有效地平衡了学习效率和模型鲁棒性之间的关系,为模型提供了一个更平滑的学习曲线,并最终实现了更好的泛化能力。
动态调整学习难度:Reweighting Vector在学生模型的早期阶段,倾向于学习更容易的样本,并且可以避免学生模型陷入局部最小值。随着学生模型能力的提高,Reweighting Vector逐渐增加难度,引入更接近决策边界的样本,以促进学生模型的泛化能力。
防止灾难性遗忘:在知识蒸馏过程中,学生模型可能会遗忘早期学到的知识。Reweighting Vector通过在训练过程中逐渐引入新的、更难的样本,有助于学生模型在保持旧知识的同时学习。
源代码链接click here
if args.method == 'cudfkd':
if epoch > int(args.epochs * args.begin_fraction) and epoch < int(args.epochs * args.end_fraction) and args.curr_option != 'none':
synthesizer.adv += args.grad_adv
这个部分是在每个epoch结束时加入的,其目的在于动态调整synthesizer(生成器的关键部件)的adv值。
# Negative Divergence.
if self.adv > 0:
s_out = self.student(samples, l=l)
if self.adv_type == 'js':
l_js = jsdiv(s_out, t_out, T=3)
loss_adv = 1.0-torch.clamp(l_js, 0.0, 1.0)
if self.adv_type == 'kl':
mask = (s_out.max(1)[1]==t_out.max(1)[1]).float()
loss_adv = -(kldiv(s_out, t_out, reduction='none', T=3).sum(1) * mask).mean()
else:
loss_adv = torch.zeros(1).to(self.device)
loss = self.lmda_ent * ent + self.adv * loss_adv+ self.oh * loss_oh + self.act * loss_act + self.bn * loss_bn
这是synthesizer类中涉及到adv计算部分的代码,可以看到adv的值很大程度上决定了最终反向传播的loss的值,所以通过调整adv的值来控制loss进而调整Generator的参数θ是十分有效的。
if args.method == 'cudfkd':
if args.dataset == 'cifar10':
alpha = 0.0001
else:
alpha = 0.00002
lamda = datafree.datasets.utils.lambda_scheduler(args.lambda_0, global_iter, alpha=alpha)
with torch.no_grad():
g,v = datafree.datasets.utils.curr_v(l=real_loss_s, lamda=lamda, spl_type=args.curr_option.split('_')[1])
这是λ-scheduler来调控lamda的值,进而将lamda的值传入curr_v函数中,从而得到我们想要的重权向量。
具体的lambda_scheduler和curr_v函数如下:
def lambda_scheduler(lambda_0, iter, alpha=0.0001, iter_0=500000000):
if iter < iter_0:
lamda = lambda_0 + alpha * iter
else:
lamda = lambda_0 + alpha * iter_0
return lamda
lambda_scheduler很简单地对lamda进行调控,iter是每个epoch中正在蒸馏的轮次。(每个epoch可能要蒸馏上百个iter)
def curr_v(l, lamda, spl_type='hard'):
if spl_type == 'hard':
v = (l < lamda).float()
g = -lamda * (v.sum())
elif spl_type == 'soft':
v = (l < lamda).float()
v *= (1 - l / lamda)
g = 0.5 * lamda * (v * v - 2 * v).sum()
elif spl_type == 'log':
v = (1 + math.exp(-lamda)) / (1 + (l - lamda).exp())
mu = 1 + math.exp(-lamda) - v
g = (mu * mu.log() + v * (v+1e-8).log() - lamda * v)
# print(g, v.min(), v)
else:
raise NotImplementedError('Not implemented of spl type {}'.format(spl_type))
return g, v
参数中l是本次iter中的real_loss,lamda是上问的lamda,type是按照自己的需要进行选择的。(该算法的数学原理不再解释)
为什么要讲这两个插件模块?因为他们真的几乎能用到所有可用的知识蒸馏模型之中,并且实现起来简洁高效。有需要的伙伴快来食用!
码字不易,还请点赞收藏支持,谢谢!!!