Today I'm reading a paper from Prof. Shen's group on knowledge distillation (KD).
2020-03-26: Re-reading this paper today. Last time I didn't pay attention to its pair-wise distillation.
Previous knowledge distillation strategies used for dense prediction tasks often directly borrow the distillation scheme for image classification and perform knowledge distillation for each pixel separately, leading to sub-optimal performance.
The abstract says that previous KD methods distill knowledge for each pixel separately, which leads to a sub-optimal solution. This refers specifically to dense prediction.
Here we propose to distill structured knowledge from large networks to small networks, taking into account the fact that dense prediction is a structured prediction problem.
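Since dense prediction is a structured prediction problem, the paper proposes to distill structured knowledge. There are two structured distillation schemes, and the paper calls the earlier per-pixel KD pixel-wise distillation.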
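For contrast, here is a minimal sketch of pixel-wise distillation as described above: the student matches the teacher's class distribution at every pixel independently, and the per-pixel losses are averaged. The KL direction, the temperature `T`, and the name `pixel_wise_kd` are my assumptions, not taken from the paper's code.

```python
import torch
import torch.nn.functional as F


def pixel_wise_kd(logits_s: torch.Tensor, logits_t: torch.Tensor, T: float = 1.0) -> torch.Tensor:
    """Per-pixel KD: KL(teacher || student) on class distributions, averaged over pixels.

    logits_s, logits_t: (N, C, H, W) class score maps of student and teacher.
    T: softmax temperature (an assumption; T=1 is a plain softmax, no T^2 scaling here).
    """
    c = logits_s.shape[1]
    # Move classes to the last dim and treat every pixel as its own "sample".
    log_p_s = F.log_softmax(logits_s / T, dim=1).permute(0, 2, 3, 1).reshape(-1, c)
    p_t = F.softmax(logits_t.detach() / T, dim=1).permute(0, 2, 3, 1).reshape(-1, c)
    # 'batchmean' divides the summed KL by the number of rows, i.e. N*H*W pixels.
    return F.kl_div(log_p_s, p_t, reduction="batchmean")
```

Each pixel contributes independently, which is exactly the "no structure" limitation the paper points out.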
pair-wise distillation
The pair-wise distillation scheme is motivated by the widely-studied pair-wise Markov random field framework. (Ref. [23] in the paper.)
holistic distillation
The holistic distillation scheme aims to align higher-order consistencies
Specifically, we study two structured distillation schemes: i) pair-wise distillation that distills the pairwise similarities by building a static graph, and ii) holistic distillation that uses adversarial training to distill holistic knowledge.
Dense Prediction
Dense prediction is a category of fundamental problems in computer vision, which learns a mapping from input objects to complex output structures, including semantic segmentation, depth estimation and object detection, among many others.
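It maps inputs to complex structured outputs, so is this kind of structured distillation really what we need for image restoration? (The idea might carry over directly, but the design seems to focus more on dense structure.)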
Apart from the holistic loss
a classifier
computes the segmentation map (???), which is upsampled to $W\times H$ as the segmentation result.
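The upsampling step can be sketched as follows, assuming bilinear interpolation of the class scores followed by a per-pixel argmax (the notes do not state the interpolation mode); `logits_to_segmentation` is a hypothetical helper name. Note that PyTorch's `size` argument is ordered (H, W).

```python
import torch
import torch.nn.functional as F


def logits_to_segmentation(logits: torch.Tensor, out_hw: tuple) -> torch.Tensor:
    """Upsample a low-resolution score map to the input size, then take the per-pixel argmax.

    logits: (N, num_classes, h, w) classifier output.
    out_hw: (H, W) target resolution.
    Returns an (N, H, W) tensor of class indices.
    """
    # Interpolate the scores (not the labels), so class boundaries stay smooth.
    up = F.interpolate(logits, size=out_hw, mode="bilinear", align_corners=False)
    return up.argmax(dim=1)
```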
pair-wise KD
import torch
import torch.nn as nn


class CriterionPairWiseforWholeFeatAfterPool(nn.Module):
    def __init__(self, scale, feat_ind):
        '''Pair-wise loss between intermediate feature maps.'''
        super(CriterionPairWiseforWholeFeatAfterPool, self).__init__()
        self.criterion = sim_dis_compute
        self.feat_ind = feat_ind  # e.g. -5: which intermediate feature to distill
        self.scale = scale        # e.g. 0.5

    def forward(self, preds_S, preds_T):
        '''
        :preds_S: student network features (indexed by feat_ind)
        :preds_T: teacher network features (indexed by feat_ind)
        '''
        feat_S = preds_S[self.feat_ind]
        feat_T = preds_T[self.feat_ind].detach()  # the teacher is fixed, so no gradient flows into it
        total_w, total_h = feat_T.shape[2], feat_T.shape[3]
        # patch_w, patch_h: the pooling window, i.e. the beta aggregation. The paper describes
        # average pooling, but this code uses max pooling; cosine similarity follows.
        # scale is not a feature-map resize factor, as I first thought, but the window size relative
        # to the feature map, so after beta aggregation the map is about int(1/scale) x int(1/scale) (2 x 2 for 0.5).
        # That small? A feature map keeps only 1/scale x 1/scale (4) nodes, so what is alpha for...
        patch_w, patch_h = int(total_w*self.scale), int(total_h*self.scale)
        # ceil_mode=True also pools the leftover border smaller than the kernel; False drops it
        maxpool = nn.MaxPool2d(kernel_size=(patch_w, patch_h), stride=(patch_w, patch_h), padding=0, ceil_mode=True)  # change
        loss = self.criterion(maxpool(feat_S), maxpool(feat_T))
        return loss

def L2(f_):
    '''-> per-pixel L2 norm over the channel dim c, shape (n, 1, h, w)'''
    return (((f_**2).sum(dim=1))**0.5).reshape(f_.shape[0], 1, f_.shape[2], f_.shape[3]) + 1e-8

def similarity(feat):
    '''
    :feat: feature map after pooling, (n, c, h, w)
    -> node x node similarity matrix, (n, hw, hw)
    '''
    feat = feat.float()
    tmp = L2(feat).detach()
    feat = feat/tmp
    # reshape -> n x c x hw
    feat = feat.reshape(feat.shape[0], feat.shape[1], -1)
    # node x node similarity: dot product between each node's channel vector and every other node's
    # (cosine similarity, since the vectors were normalized above)
    return torch.einsum('icm,icn->imn', [feat, feat])
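
def sim_dis_compute(f_S, f_T):
    '''
    :f_S, f_T: pooled student / teacher features
    '''
    # squared difference of the similarity matrices, averaged over the (hw)^2 node pairs and the batch
    sim_err = ((similarity(f_T) - similarity(f_S))**2)/((f_T.shape[-1]*f_T.shape[-2])**2)/f_T.shape[0]
    sim_dis = sim_err.sum()
    return sim_dis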
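Since the comment above points out that the paper describes average pooling while the code uses max pooling with a window relative to the feature size, here is a compact alternative sketch that average-pools fixed beta×beta patches and compares all node pairs (global alpha). This is my own rewrite under that reading of the paper, not the authors' implementation; `pairwise_kd` and `pairwise_similarity` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def pairwise_similarity(feat: torch.Tensor) -> torch.Tensor:
    """(N, C, h, w) -> (N, h*w, h*w) cosine similarity between all node pairs."""
    feat = feat.flatten(2)                  # N x C x M, with M = h*w nodes
    feat = F.normalize(feat, p=2, dim=1)    # unit-length channel vector per node
    return torch.bmm(feat.transpose(1, 2), feat)


def pairwise_kd(feat_s: torch.Tensor, feat_t: torch.Tensor, beta: int = 2) -> torch.Tensor:
    """Pair-wise distillation with beta x beta average-pooled nodes and a global graph.

    Student and teacher features must share the spatial size (H, W); channel counts may differ.
    """
    # Each node aggregates a beta x beta patch (the granularity beta).
    node_s = F.avg_pool2d(feat_s, kernel_size=beta, stride=beta, ceil_mode=True)
    node_t = F.avg_pool2d(feat_t.detach(), kernel_size=beta, stride=beta, ceil_mode=True)
    sim_s, sim_t = pairwise_similarity(node_s), pairwise_similarity(node_t)
    # Mean squared difference over all (node, node) pairs and the batch.
    return ((sim_s - sim_t) ** 2).mean()
```

Here `beta` is a patch size in pixels, whereas `scale` in the original code makes the window a fraction of the feature map, which fixes the node count at about (1/scale)² regardless of resolution. Because the similarity matrices are M×M whatever the channel count, the teacher and student can have different numbers of channels. The cost grows quickly with the node count: for a 64×64 feature map, beta=2 gives 32×32 = 1024 nodes and a 1024×1024 similarity matrix per image, while beta=4 gives 256 nodes and 65,536 entries, 16 times fewer.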
beta=2 with alpha=global gives the best semantic segmentation results, and when choosing between alpha and beta, prefer increasing beta to reduce computation.
Holistic distillation
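conditional GAN
The JS divergence breaks down when the two distributions barely overlap (KL becomes meaningless and JS is a constant), so the Wasserstein distance, also called the Earth Mover distance, is used instead.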
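With the Wasserstein distance the critic outputs an unbounded score instead of a probability, and the losses reduce to differences of mean scores. A minimal sketch under the usual WGAN convention, treating the teacher's segmentation maps as "real" and the student's as "fake"; the function names and the sign convention are my assumptions, and the paper's exact formulation may differ.

```python
import torch


def critic_loss(d_teacher: torch.Tensor, d_student: torch.Tensor) -> torch.Tensor:
    """WGAN critic loss: minimizing it maximizes E[D(teacher)] - E[D(student)].

    d_teacher, d_student: critic scores, one per image.
    """
    return d_student.mean() - d_teacher.mean()


def student_adv_loss(d_student: torch.Tensor) -> torch.Tensor:
    """Adversarial term for the student (the generator): push its score up."""
    return -d_student.mean()
```

If the critic is (approximately) 1-Lipschitz, `-critic_loss` at the critic's optimum estimates the Earth Mover distance between the two output distributions, which stays informative even when they do not overlap.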
The discriminator uses self-attention residual blocks; see the paper for the positions and number of the self-attention and residual blocks.
Wasserstein distance
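# Weight clipping from the original WGAN; discriminator and opt.clip_value come from that training script.
# The paper instead enforces the Lipschitz constraint with a gradient penalty.
for p in discriminator.parameters():
    p.data.clamp_(-opt.clip_value, opt.clip_value)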
We further add a pooling layer to pool the holistic embedding into a score
pooling as score
The Lipschitz requirement is satisfied by the gradient penalty
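Because the paper uses a gradient penalty rather than weight clipping, here is a minimal WGAN-GP sketch with a conditional critic that pools its output map into one score per image. `TinyCritic` is a toy stand-in (no self-attention residual blocks), conditioning by concatenating the image is my assumption, and `lam=10` is the common WGAN-GP default, not a value taken from this paper.

```python
import torch
import torch.nn as nn


class TinyCritic(nn.Module):
    """Hypothetical conditional critic: (score map, image) -> one score per image."""

    def __init__(self, num_classes: int, img_channels: int = 3, width: int = 32):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(num_classes + img_channels, width, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(width, 1, 3, padding=1),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)  # pool the holistic embedding into a score

    def forward(self, seg: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        x = torch.cat([seg, img], dim=1)           # condition on the input image
        return self.pool(self.body(x)).flatten()   # shape (N,)


def gradient_penalty(critic, seg_real, seg_fake, img, lam=10.0):
    """WGAN-GP: penalize (||grad_x D(x_hat)||_2 - 1)^2 at random interpolates x_hat."""
    n = seg_real.size(0)
    eps = torch.rand(n, 1, 1, 1, device=seg_real.device)
    # Interpolate only the segmentation maps; the conditioning image stays fixed.
    x_hat = (eps * seg_real + (1 - eps) * seg_fake).detach().requires_grad_(True)
    scores = critic(x_hat, img)
    # create_graph=True so the penalty itself can be backpropagated into the critic.
    grads, = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return lam * ((grad_norm - 1) ** 2).mean()
```

The penalty is added to the critic loss only; the student's loss does not include it.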
Self-attention module:
import torch
import torch.nn as nn


class selfAttn(nn.Module):
    def __init__(self, dim):
        super(selfAttn, self).__init__()
        self.query_conv = nn.Conv2d(dim, dim//8, 1)
        self.key_conv = nn.Conv2d(dim, dim//8, 1)
        self.value_conv = nn.Conv2d(dim, dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        # dim=-1: softmax over all columns of each row, so every row of the attention map sums to 1
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n, c, h, w = x.shape
        proj_query = self.query_conv(x).view(n, -1, h*w).permute(0, 2, 1)  # n x c' x hw -> n x hw x c'
        proj_key = self.key_conv(x).view(n, -1, h*w)  # n x c' x hw
        energy = torch.bmm(proj_query, proj_key)  # n x hw x hw
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(n, -1, h*w)  # n x c x hw
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # n x c x hw
        out = out.view(n, c, h, w)
        out = self.gamma*out + x  # ResNet-style skip connection
        return out, attention
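To check the comment about `dim=-1`: row i of `attention` is a distribution over all positions j, and the `bmm` with the transposed attention makes output position i a weighted average of the value vectors at every position j. The einsum below writes the same product index by index; `attend` is a hypothetical helper used only for illustration.

```python
import torch


def attend(value: torch.Tensor, attention: torch.Tensor) -> torch.Tensor:
    """Same product as torch.bmm(proj_value, attention.permute(0, 2, 1)).

    value: (N, C, M) projected values; attention: (N, M, M) with rows summing to 1.
    out[n, c, i] = sum_j value[n, c, j] * attention[n, i, j]
    """
    return torch.einsum("ncj,nij->nci", value, attention)
```

Because `gamma` is initialized to zero, the block starts out as an identity mapping and learns how much of the attention output to mix in.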
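I can hardly find any papers on knowledge distillation for low-level vision. In the GitHub awesome-knowledge-distillation list I only found one related to super-resolution, and it's from Xidian University, haha. Let's take a look.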
recently, Zhang et al. also introduced spatial attention (non-local module) into the residual block and then constructed residual non-local attention network (RNAN) [37] for various image restoration tasks.
spatial attention for various image restoration tasks
retains partial information: one part of the features is kept, the block further treats the other features, and finally aggregates the features distilled at each step.
So this is what they call distillation? Contrast-aware channel attention (CCA) layer, specifically related to the low-level vision task.
Channel split? It simply splits the channels into two parts (64 = 16 + 48); the first part serves as the refined features.
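The split-and-retain idea can be sketched as a block that, at each step, keeps 16 of 64 channels as refined features, sends the other 48 to the next convolution, and finally concatenates all retained parts and fuses them with a 1x1 convolution plus a residual connection. The number of steps, the LeakyReLU slope, and the class name `ProgressiveSplitBlock` are my assumptions; this is a simplified sketch, not the paper's exact block (for instance, it omits the channel attention layer).

```python
import torch
import torch.nn as nn


class ProgressiveSplitBlock(nn.Module):
    """Hypothetical sketch: at each step keep `keep` channels as refined features
    and pass the remaining channels on for further processing."""

    def __init__(self, channels: int = 64, keep: int = 16, steps: int = 3):
        super().__init__()
        self.keep, self.rest = keep, channels - keep
        convs = [nn.Conv2d(channels, channels, 3, padding=1)]
        convs += [nn.Conv2d(self.rest, channels, 3, padding=1) for _ in range(steps - 1)]
        self.convs = nn.ModuleList(convs)
        self.last = nn.Conv2d(self.rest, keep, 3, padding=1)
        self.fuse = nn.Conv2d(keep * (steps + 1), channels, 1)
        self.act = nn.LeakyReLU(0.05)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        kept, h = [], x
        for conv in self.convs:
            out = self.act(conv(h))
            # channel split, e.g. 64 = 16 (retained) + 48 (processed further)
            refined, h = torch.split(out, [self.keep, self.rest], dim=1)
            kept.append(refined)
        kept.append(self.last(h))  # the final step keeps everything it produces
        # aggregate all distilled parts, then add the input back
        return self.fuse(torch.cat(kept, dim=1)) + x
```

With the defaults, four 16-channel parts are concatenated back to 64 channels, so the output shape matches the input.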
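CCA does the following: for each channel it computes the mean and the standard deviation and adds them together as the attention statistic. I always feel this kind of addition or multiplication lacks justification and doesn't really convince me, but since these are learned features they may represent something. Is the result of the formula above $c\times 1\times 1$?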
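A minimal sketch of that computation, assuming the "contrast" statistic is the per-channel standard deviation plus mean, fed into a squeeze-and-excitation-style 1x1 convolution gate; `reduction=16` and the names `contrast_stat` / `ContrastChannelAttention` are my assumptions. It also answers the shape question: the statistic is $c\times 1\times 1$ per image, one weight per channel, broadcast over the spatial dimensions.

```python
import torch
import torch.nn as nn


def contrast_stat(x: torch.Tensor) -> torch.Tensor:
    """Per-channel standard deviation + mean over H, W: (N, C, H, W) -> (N, C, 1, 1)."""
    mean = x.mean(dim=(2, 3), keepdim=True)
    std = ((x - mean) ** 2).mean(dim=(2, 3), keepdim=True).sqrt()
    return std + mean


class ContrastChannelAttention(nn.Module):
    """Sketch: contrast statistic -> 1x1 conv bottleneck -> sigmoid -> rescale channels."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = self.gate(contrast_stat(x))  # (N, C, 1, 1), each weight in (0, 1)
        return x * weights                      # broadcast over H and W
```

A constant channel has zero standard deviation, so its statistic reduces to its mean.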
Adaptive cropping strategy (ACS)
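OK, at this point we roughly understand how knowledge distillation is done. It is usually classified as transfer learning: the prior knowledge of a large network is transferred to a small one. But why does this work? Let's learn Knowledge Distillation from the beginning.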