Distilling Holistic Knowledge with Graph Neural Networks: Paper Walkthrough

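This is an ICCV 2021 paper that proposes a new form of knowledge distillation, Holistic Knowledge Distillation.
Figure 1 of the paper contrasts three forms of knowledge distillation: Individual, Relational, and Holistic. Drawing on the posts "Relational Knowledge Distillation解读" and "Relational Knowledge Distillation", here is a brief overview of how they differ: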
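We start from the figure in the Relational Knowledge Distillation paper to explain conventional knowledge distillation (KD) and Relational Knowledge Distillation (RKD). Conventional KD extracts a feature vector for each image separately with the student and the teacher, then measures the gap between the student's and the teacher's outputs with KL divergence or a similar criterion, which is why it is described as "point to point". RKD builds on conventional KD by learning from the feature vectors of several images jointly, through distance-wise (second-order) and angle-wise (third-order) distillation losses.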
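The paper argues that extracting only per-instance information, or only information between instances, is not enough. It therefore proposes Holistic Knowledge Distillation, which integrates conventional KD and Relational Knowledge Distillation.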


Let $X=\{x_1,x_2,\dots,x_N\}$ be a set of samples drawn from a $K$-class dataset, with corresponding labels $Y=\{y_1,y_2,\dots,y_N\}$, where $N$ is the number of samples. $W^t$ and $W^s$ denote the pretrained teacher model with fixed parameters and the student model with trainable parameters. The feature representations of the teacher and student (often used in Relational Knowledge Distillation) are $f^t \in R^{d^t}$ and $f^s \in R^{d^s}$, where $d^t$ and $d^s$ may differ when the two architectures differ. $z^t$ and $z^s$ are the logits predicted by the teacher and the student.
The logits are turned into probabilities with a temperature-scaled softmax:
$$p_i(z;r)=\mathrm{Softmax}(z;r)=\frac{e^{z_i/r}}{\sum_{k=1}^{K}e^{z_k/r}}$$
Here the temperature starts at $r=1$. As $r$ increases, the softmax output distribution becomes smoother and its entropy grows, so the information carried by the negative labels is relatively amplified and training pays more attention to them.
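To make the effect of the temperature concrete, here is a minimal NumPy sketch of the formula above. The function names and the example logits are illustrative choices, not taken from the paper.

```python
import numpy as np

def softmax_with_temperature(z, r=1.0):
    """Temperature-scaled softmax over the last axis."""
    z = np.asarray(z, dtype=float) / r
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def entropy(p):
    """Shannon entropy of a probability vector (natural log)."""
    return float(-(p * np.log(p)).sum())

logits = np.array([4.0, 1.0, 0.5])           # hypothetical logits for K = 3 classes
p_sharp = softmax_with_temperature(logits, r=1.0)
p_soft = softmax_with_temperature(logits, r=4.0)
```

With the larger temperature, the non-target classes receive more probability mass and the entropy of the distribution increases, which is the behaviour described above.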
The distillation term averages the KL divergence between the student's and the teacher's softened predictions over the $N$ samples:
$$L_{KD}(p^s,p^t)=\frac{1}{N}\sum_{i=1}^{N}KL(p_i^s,p_i^t)$$
where $p_i^s$ and $p_i^t$ are the predictions for sample $i$. In vanilla KD, the student's loss is:
$$L = L_{CE}(p^s,y) + \lambda L_{KD}(p^s,p^t)$$
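The vanilla KD objective can be sketched in NumPy as follows. This is a minimal illustration under assumptions: the temperature `r=4.0`, the weight `lam=0.5`, and the example logits are made up. The KL argument order follows the formula as written in this post, while many implementations compute KL(teacher || student) instead.

```python
import numpy as np

def softmax(z, r=1.0):
    z = np.asarray(z, dtype=float) / r
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl_divergence(p, q, eps=1e-12):
    """Row-wise KL(p || q) for batches of probability vectors."""
    return (p * (np.log(p + eps) - np.log(q + eps))).sum(axis=-1)

def cross_entropy(z_s, y, eps=1e-12):
    """Mean cross-entropy of the student's (temperature 1) predictions."""
    p = softmax(z_s)
    return float(-np.log(p[np.arange(len(y)), y] + eps).mean())

def vanilla_kd_loss(z_s, z_t, y, r=4.0, lam=0.5):
    """L = L_CE + lam * L_KD, with L_KD averaged over the batch."""
    l_kd = float(kl_divergence(softmax(z_s, r), softmax(z_t, r)).mean())
    return cross_entropy(z_s, y) + lam * l_kd

# Hypothetical student/teacher logits for N = 2 samples and K = 3 classes.
z_s = np.array([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
z_t = np.array([[3.0, 0.2, 0.1], [0.1, 2.5, 0.2]])
y = np.array([0, 1])
loss = vanilla_kd_loss(z_s, z_t, y)
```

When the student's logits equal the teacher's, the KL term vanishes and the loss reduces to plain cross-entropy.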

Attributed Context Graph Construction

A batch of images is fed to the teacher and the student to obtain the feature representations $f^t$, $f^s$ and the predicted probabilities $p^t$, $p^s$. Two attributed graphs are then built, $G^t=\{A^t,F^t\}$ and $G^s=\{A^s,F^s\}$, where $F^t \in R^{N \times d^t}$ and $F^s \in R^{N \times d^s}$ are the node attributes. The adjacency matrices $A^t$, $A^s$ are derived from $p^t$, $p^s$:
$$A^t=\phi(p^t),\quad A^s=\phi(p^s)$$
where $\phi(\cdot)$ is a KNN-based graph construction function (I do not fully understand how this graph is built). $G^t$ is fixed. Compared with a fully connected graph, a KNN graph filters out unrelated sample pairs. A short digression on KNN follows (it is based mostly on the posts "一文搞懂k近邻(k-NN)算法(一)" and "Python—KNN分类算法(详解)").
KNN stands for K Nearest Neighbors: the node to be predicted is classified from the K nodes closest to it.
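As a minimal illustration of the idea (a generic majority-vote classifier with Euclidean distance, not something taken from the paper), the data and the function name below are made up:

```python
import numpy as np
from collections import Counter

def knn_predict(train_x, train_y, query, k=3):
    """Label a query point by majority vote among its k nearest training points."""
    dists = np.linalg.norm(train_x - query, axis=1)  # Euclidean distance to every point
    nearest = np.argsort(dists)[:k]                  # indices of the k closest points
    votes = Counter(train_y[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Two well-separated toy clusters.
train_x = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],
                    [1.0, 1.0], [0.9, 1.1], [1.1, 0.9]])
train_y = ["a", "a", "a", "b", "b", "b"]

label_near_origin = knn_predict(train_x, train_y, np.array([0.15, 0.15]))
label_near_one = knn_predict(train_x, train_y, np.array([0.95, 1.0]))
```

The KNN graph used in the paper relies on the same neighbour search, but keeps the neighbour indices as edges instead of voting on a label.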
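Below is the knn_graph part of the paper's source code, together with the corresponding code from the DGL documentation page "Source code for dgl.transform":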

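import torch
import torch.nn.functional as F
from scipy import sparse
from dgl import DGLGraph
from dgl import backend as B

# eps is not defined in the excerpt; a small constant is assumed here.
eps = 1e-8

def cos_distance_softmax(x):
    soft = F.softmax(x, dim=2)
    # L2 norm of each softmax vector
    w = soft.norm(p=2, dim=2, keepdim=True)
    soft_t = B.swapaxes(soft, -1, -2)  # transpose of soft
    # 1 - cosine similarity: (soft @ soft^T) divided by the outer product of the norms
    return 1 - soft @ soft_t / (w @ B.swapaxes(w, -1, -2)).clamp(min=eps)

def knn_graph(x, k):
    # Accept a single point set (N, d) as well as a batch (B, N, d).
    if B.ndim(x) == 2:
        x = B.unsqueeze(x, 0)
    n_samples, n_points, _ = B.shape(x)

    dist = cos_distance_softmax(x)  # I am not sure why this particular distance is used

    # Zero the diagonal, then set it to -1 so every node picks itself first.
    eye = torch.eye(n_points, n_points).to(dist.device)
    dist = dist * B.unsqueeze(1 - eye, 0)
    dist = dist - B.unsqueeze(eye, 0)

    # Indices of the k smallest distances for every node.
    k_indices = B.argtopk(dist, k, 2, descending=False)

    dst = B.copy_to(k_indices, B.cpu())
    src = B.zeros_like(dst) + B.reshape(B.arange(0, n_points), (1, -1, 1))

    # Shift node ids so that each sample in the batch gets its own block of nodes.
    per_sample_offset = B.reshape(B.arange(0, n_samples) * n_points, (-1, 1, 1))
    dst += per_sample_offset
    src += per_sample_offset
    dst = B.reshape(dst, (-1,))
    src = B.reshape(src, (-1,))
    adj = sparse.csr_matrix((B.asnumpy(B.zeros_like(dst) + 1), (B.asnumpy(dst), B.asnumpy(src))))

    # Older DGL constructor (read-only graph), kept as in the excerpt.
    g = DGLGraph(adj, readonly=True)
    return g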
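To see what the graph construction amounts to without the DGL machinery, here is a small NumPy sketch. It assumes the inputs are already probability vectors (the excerpt applies softmax to its input first) and it stores neighbours along rows, which is my own convention. The function name and the toy predictions are hypothetical.

```python
import numpy as np

def knn_adjacency(probs, k):
    """Binary adjacency: adj[i, j] = 1 if j is among the k nearest neighbours of i.

    Distance is cosine distance between prediction vectors. The diagonal is set
    to -1 so that each node always selects itself, mirroring the excerpt above.
    """
    unit = probs / np.linalg.norm(probs, axis=1, keepdims=True)
    dist = 1.0 - unit @ unit.T
    np.fill_diagonal(dist, -1.0)
    nearest = np.argsort(dist, axis=1)[:, :k]        # k smallest distances per row
    adj = np.zeros_like(dist)
    rows = np.repeat(np.arange(len(probs)), k)
    adj[rows, nearest.ravel()] = 1.0
    return adj

# Four samples: the first two predict class 0, the last two predict class 1.
probs = np.array([[0.90, 0.05, 0.05],
                  [0.80, 0.10, 0.10],
                  [0.05, 0.90, 0.05],
                  [0.10, 0.80, 0.10]])
adj = knn_adjacency(probs, k=2)
```

With `k=2`, each sample is linked to itself and to the one other sample with the most similar prediction, so pairs with unrelated predictions get no edge.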

Holistic Knowledge Distillation

A Topology Adaptive Graph Convolution Network (TAGCN) extracts the holistic knowledge of $G^t$ and $G^s$, denoted $H^t \in R^{N \times g^t}$ and $H^s \in R^{N \times g^s}$:
$$H^t = \sum_{l=0}^{L}\left(D_t^{-1/2}A^tD_t^{-1/2}\right)^l F^t\theta_l^t$$
$$H^s = \sum_{l=0}^{L}\left(D_s^{-1/2}A^sD_s^{-1/2}\right)^l F^s\theta_l^s$$
where $g^t$, $g^s$ are the dimensions of the graph representations, $D_t$ is the diagonal degree matrix of the teacher graph with entries $(D_t)_{ii}=\sum_j A_{ij}^t$ (and likewise $D_s$ for the student), and $\theta_l^t$, $\theta_l^s$ are learnable weights.
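The TAGCN formula can be evaluated directly with dense matrices. The sketch below is a NumPy illustration under assumptions: random matrices stand in for the learnable $\theta_l$, the graph is a made-up three-node example, and degrees are clamped to at least 1 to avoid division by zero (my choice for the sketch).

```python
import numpy as np

def tag_propagate(adj, feats, thetas):
    """Compute sum_l (D^{-1/2} A D^{-1/2})^l F theta_l for l = 0 .. len(thetas) - 1."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(np.maximum(deg, 1.0))  # clamp degrees to avoid 1/0
    a_norm = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    out = np.zeros((feats.shape[0], thetas[0].shape[1]))
    hop = feats                                       # l = 0 term: A_norm^0 F = F
    for theta in thetas:
        out += hop @ theta
        hop = a_norm @ hop                            # move on to the next power of A_norm
    return out

rng = np.random.default_rng(0)
feats = rng.normal(size=(3, 2))                       # N = 3 nodes, d = 2 features
thetas = [rng.normal(size=(2, 4)) for _ in range(3)]  # L = 2, output dimension g = 4
adj = np.array([[1.0, 1.0, 0.0],
                [1.0, 1.0, 1.0],
                [0.0, 1.0, 1.0]])                     # small graph with self-loops
H = tag_propagate(adj, feats, thetas)
```

Each term mixes information from nodes up to $l$ hops away, so the output for a node depends on its whole neighbourhood rather than on its own features alone.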
Mutual information is used to distill the student: training maximizes the mutual information between $H^t$ and $H^s$.
$$\underset{W^s,\theta^t,\theta^s}{L_{HOL}} = -I(H^t, H^s)$$
where $I(H^t, H^s)$ is computed with the InfoNCE estimator:
$$I(H^t, H^s) \geq E\left[\frac{1}{N}\sum_{i=1}^{N}\log\frac{e^{f(h_i^t, h_i^s)}}{\frac{1}{N}\sum_{j=1}^{N}e^{f(h_i^t, h_j^s)}}\right]$$
Here $f(\cdot)$ is cosine similarity, and $h_i^t$, $h_i^s$ are the representations of instance $i$ learned by the teacher and the student respectively.
The final objective of holistic knowledge distillation is $L=L_{CE}+\beta L_{HOL}$.
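The InfoNCE bound can be sketched for a single batch as follows. This is an illustration only: it assumes the teacher and student representations have the same dimension so that cosine similarity is defined, and the function names and random inputs are hypothetical.

```python
import numpy as np

def cosine_sim_matrix(a, b):
    """scores[i, j] = cosine similarity between a[i] and b[j]."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def infonce_bound(h_t, h_s):
    """Batch estimate of (1/N) sum_i log( e^{f(i,i)} / ((1/N) sum_j e^{f(i,j)}) )."""
    scores = cosine_sim_matrix(h_t, h_s)
    positives = np.diag(scores)                        # f(h_i^t, h_i^s)
    log_mean_exp = np.log(np.exp(scores).mean(axis=1)) # log (1/N) sum_j e^{f(h_i^t, h_j^s)}
    return float((positives - log_mean_exp).mean())

def holistic_loss(h_t, h_s):
    """L_HOL = -I(H^t, H^s), with I replaced by its InfoNCE estimate."""
    return -infonce_bound(h_t, h_s)

rng = np.random.default_rng(0)
h = rng.normal(size=(8, 4))             # stand-in graph representations, N = 8
aligned = infonce_bound(h, h)           # student matches the teacher exactly
misaligned = infonce_bound(h, h[::-1])  # same vectors paired with the wrong instances
```

The estimate is larger when each student representation matches its own teacher representation better than those of other instances, and it can never exceed $\log N$.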

"""Torch Module for Topology Adaptive Graph Convolutional layer"""
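import torch as th
from torch import nn

import dgl.function as fn  # the DGL source uses a package-relative import here

class TAGConv(nn.Module):
    def __init__(self, in_feats, out_feats, k=2, bias=True, activation=None):
        super(TAGConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._k = k
        self._activation = activation
        # One linear layer applied to the concatenation of the k + 1 hop features.
        self.lin = nn.Linear(in_feats * (self._k + 1), out_feats, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.lin.weight, gain=gain)

    def forward(self, graph, feat):
        with graph.local_scope():
            assert graph.is_homogeneous, 'Graph is not homogeneous'

            # D^{-1/2}, with degrees clamped to at least 1
            norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
            # Reshape to (N, 1, ...) so it broadcasts over the feature dimensions
            # (a reshape, not a transpose).
            shp = norm.shape + (1,) * (feat.dim() - 1)
            norm = th.reshape(norm, shp).to(feat.device)

            # fstack collects X, (D^-1/2 A D^-1/2) X, (D^-1/2 A D^-1/2)^2 X, ...
            fstack = [feat]
            for _ in range(self._k):
                rst = fstack[-1] * norm
                graph.ndata['h'] = rst

                # Sum the neighbours' features (newer DGL releases call this fn.copy_u).
                graph.update_all(fn.copy_src(src='h', out='m'),
                                 fn.sum(msg='m', out='h'))
                rst = graph.ndata['h']  # aggregated neighbour features for each node
                rst = rst * norm
                fstack.append(rst)  # missing in the excerpt; without it only feat is used

            rst = self.lin(th.cat(fstack, dim=-1))

            if self._activation is not None:
                rst = self._activation(rst)

            return rst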


Efficient Training

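Because the InfoNCE estimator would need every sample in the dataset to serve as a negative, the cost is too high for large datasets, so the paper adopts a Memory Bank strategy for storage. The paper then builds on the fact that the samples of each mini-batch are drawn at random (I could not follow the rest of the argument and do not want to mistranslate it, so I defer to the original text of the paper here. If anyone understands it, please explain in the comments).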
$$L_{HOL}=\sum_{i=1}^{N}\left[\log\frac{e^{f(h_i^t, h_i^s)}}{e^{f(h_i^t, h_i^s)}+\sum_{j=1,j \neq i}^{N}e^{f(h_i^t, f_j^s)}}+\log\frac{e^{f(h_i^s, h_i^t)}}{e^{f(h_i^s, h_i^t)}+\sum_{j=1,j \neq i}^{N}e^{f(h_i^s, f_j^t)}}\right]$$
Note the sign: each log term is at most zero and increases as the positive pair dominates the negatives, so for consistency with $L_{HOL}=-I(H^t, H^s)$ above, the expression should be read with a leading minus sign when it is used as a loss to minimize.
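The two-directional expression above can be written out term by term. The sketch below is only an illustration of the formula as it appears in this post: it assumes that all vectors share one dimension so that cosine similarity is defined, and it takes the negatives `neg_s` and `neg_t` (the $f_j^s$ and $f_j^t$ terms) as plain arrays, since how they are obtained is not explained here. All names are hypothetical.

```python
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def symmetric_log_ratio_sum(h_t, h_s, neg_s, neg_t):
    """Sum over i of the teacher-to-student and student-to-teacher log-ratio terms."""
    n = len(h_t)
    total = 0.0
    for i in range(n):
        others = [j for j in range(n) if j != i]
        pos_ts = np.exp(cosine(h_t[i], h_s[i]))
        neg_ts = sum(np.exp(cosine(h_t[i], neg_s[j])) for j in others)
        pos_st = np.exp(cosine(h_s[i], h_t[i]))
        neg_st = sum(np.exp(cosine(h_s[i], neg_t[j])) for j in others)
        total += np.log(pos_ts / (pos_ts + neg_ts)) + np.log(pos_st / (pos_st + neg_st))
    return total

rng = np.random.default_rng(0)
h_t = rng.normal(size=(4, 3))    # stand-in teacher representations
h_s = rng.normal(size=(4, 3))    # stand-in student representations
neg_s = rng.normal(size=(4, 3))  # stand-in negatives on the student side
neg_t = rng.normal(size=(4, 3))  # stand-in negatives on the teacher side
value = symmetric_log_ratio_sum(h_t, h_s, neg_s, neg_t)
loss_to_minimize = -value        # negate to obtain a quantity to minimize
```

The sum is never positive, and it reaches zero only when there are no negatives to compete with the positive pair.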


