论文理解 Linkage Based Face Clustering via Graph Convolution Network

论文理解 Linkage Based Face Clustering via Graph Convolution Network

  • 背景
  • 要解决的问题
  • 基于GCN的人脸图像聚类
    • 图卷积层
    • 节点合并
    • KNN搜索
  • MxNet复现
    • GCN Layer
    • 其他

背景

其实是利用GCN进行人脸图像聚类
论文:Linkage Based Face Clustering via Graph Convolution Network
作者提供的代码:GCN Clustering

要解决的问题

如何在无ID且未知多少类的情况下,对人脸进行聚类
论文本身提出了将聚类问题看作是节点连接的预测问题的观点,即如果两张人脸图像属于同一个ID,则这两张人脸图像之间就存在连接;这里会用到图卷积网络,预习请点击链接

基于GCN的人脸图像聚类

简单点说,作者通过构造了一个GCN网络去进行预测,具体见下图
论文理解 Linkage Based Face Clustering via Graph Convolution Network_第1张图片
以上,为了实现这个GCN,你还需要准备以下东西:

  • 人脸特征提取模型
  • KNN搜索方法
  • 人脸识别用的数据库

流程上来说,作者将聚类过程划分为了对多个子图(SubGraphs)进行连接预测,然后再将预测结果链接在一起;以中心节点以及其连接(K-Hop)构成了作者论文中提到的Instance Pivot Subgraph(IPS),然后使用GCN预测其他节点是否应该与中心节点相连。

图卷积层

在GCN中主要是配合拉普拉斯矩阵对图结构进行处理,论文中定义的GCN层如下所示
Y = σ ( [ X ∣ ∣ G X ] W ) Y=\sigma([X||GX]W) Y=σ([XGX]W)
其中, X X X为特征矩阵,激活函数 σ \sigma σ使用的是ReLU, G = Λ − 1 2 A Λ − 1 2 G=\Lambda^{-\frac{1}{2}}A\Lambda^{-\frac{1}{2}} G=Λ21AΛ21,符号 ∣ ∣ || 是Concate操作; A A A是图的邻接矩阵, Λ i i = ∑ j A i j \Lambda_{ii}=\sum_{j}{A_{ij}} Λii=jAij,对这里的 G G G翻译一下,就是对邻接矩阵的每行进行了归一化,目的是为了使 G X GX GX的尺度不变;
这里作者对预习中所提到的使用拉普拉斯矩阵进行图卷积进行了一些轻微(点都不轻微好么!!!)的改动。对比Symmetric normalized Laplacian定义(为了方便理解,以下统一了符号):
L s y m = Λ − 1 2 ( Λ − A ) Λ − 1 2 L^{sym}=\Lambda^{-\frac{1}{2}}(\Lambda-A)\Lambda^{-\frac{1}{2}} Lsym=Λ21(ΛA)Λ21
可以知道 L s y m = I − Λ − 1 2 A Λ − 1 2 L^{sym}=I-\Lambda^{-\frac{1}{2}}A\Lambda^{-\frac{1}{2}} Lsym=IΛ21AΛ21,通过Renormalization Trick令 A ~ = I + A \widetilde{A}=I+A A =I+A Λ ~ i i = ∑ j A ~ i j \widetilde{\Lambda}_{ii}=\sum_{j}{\widetilde{A}_{ij}} Λ ii=jA ij,则有图卷积层的定义:
Y = σ ( Λ ~ − 1 2 A ~ Λ ~ − 1 2 X W ) Y=\sigma(\widetilde{\Lambda}^{-\frac{1}{2}}\widetilde{A}\widetilde{\Lambda}^{-\frac{1}{2}}XW) Y=σ(Λ 21A Λ 21XW)
那么拉普拉斯矩阵在图卷积层中干了什么呢?直白点说,就是将自节点和邻接节点做了加权平均;那么作者做了啥改动呢?将邻接节点做了平均并和自节点连接起来。

节点合并

这部分其实作者并没有在论问题提太多,不过分析作者提供的代码,应该使用了几种方式进行尝试,包括直接使用固定阈值连接节点后再使用宽度优先搜素得到聚类结果,不过作者给出的代码在合并阶段也采用了一些技巧,比如使用可变阈值以及最大合并数来防止聚类结果中某一类出现过大的聚类(我在使用Chinese Whisper算法进行聚类时,若直接使用固定阈值,会出现将不同人聚类到一类,然后导致这类中的样本数量占比超过总样本数的50%以上,可以认为对算法来说,这一类样本属于Hard Sample)

KNN搜索

过分的是,作者没有提是如何进行KNN搜索生成图的邻接节点的,代码中说可以使用任意方法。嗯,当你看到巨大的数据集以及龟速的搜索速度的时候,我决定使用Facebook的FAISS库来加速了

MxNet复现

由于部署需要,我用mxnet复现了作者的算法 ,不是我不喜欢pytorch,真的。具体请先参考我的Azure Research目录下的GraphGCN ,我不是在给微软打广告,真的。因为我的代码被我自己推翻了好几次,所以这里直接放上GraphGCN的主要实现部分,仅供参考

class GCN(gluon.HybridBlock):
    def __init__(self):
        super(GCN, self).__init__()
        with self.name_scope():
            self.bn = nn.BatchNorm(in_channels=512, axis=2)
            self.conv1 = GraphConv(512,512)
            self.conv2 = GraphConv(512,512)
            self.conv3 = GraphConv(512,256)
            self.conv4 = GraphConv(256,256)

            self.conv5 = nn.Dense(256)
            self.prelu = nn.PReLU()
            self.conv6 = nn.Dense(1, activation='sigmoid')
        pass

    def hybrid_forward(self, F, x, A):
        # BND
        x = self.bn(x)
        x = self.conv1(x, A)
        x = self.conv2(x, A)
        x = self.conv3(x, A)
        x = self.conv4(x, A)

        x = F.reshape(x, (-3, 0))

        x = self.conv5(x)
        x = self.prelu(x)
        x = self.conv6(x)
        return x
    pass

GCN Layer

以下是MxNet中实现的GraphConv部分,其中参数A为邻接矩阵,x为特征矩阵

class GraphConv(gluon.HybridBlock):
    def __init__(self, in_channels, channels):
        super(GraphConv, self).__init__()
        with self.name_scope():
            self.weight = self.params.get('weight',shape=(channels, 2*in_channels), 
                                                   init='xavier', dtype='float32', allow_deferred_init=True)
            self.bias = self.params.get('bias', shape=channels, init='zeros', allow_deferred_init=True)
            self.relu = nn.Activation('relu')
        pass
    
    def hybrid_forward(self, F, x, A, weight, bias):
        f = F.concat(x, F.batch_dot(A, x), dim=2)
        y = F.FullyConnected(data=f, weight=weight, bias=bias, 
                             num_hidden=self.weight.shape[0], flatten=False, no_bias=False) # BNDxDF=BNF
        z = self.relu(y)
        return z

其他

数据的预处理部分和后处理我是用自己的库搭建的,具体使用到的模块包括utils里的fast_search,numpy等,后处理部分目前全部封装到了face里的clusterer中,具体实现是放在ChaosMX中的gcn.cpp文件中,暂时先这样吧,懒了~谢谢观看!

你可能感兴趣的:(深度学习,人脸聚类,GCN)