[GCN] Verifying and explaining the renormalization trick

Table of contents

      • Introduction
      • Mathematical derivation of the example
      • Programmatic verification of the example
      • Some understanding of the results
      • What the renormalization trick is for
      • Appendix: program source code

Introduction

The paper Semi-Supervised Classification With Graph Convolutional Networks proposes a normalization of the adjacency matrix $A$, which the paper calls the renormalization trick.

It is expressed as:

$$I_N + D^{-1/2} A D^{-1/2} \;\longrightarrow\; \widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$$

where:

  • $\widetilde{A} = A + I_N$
  • $\widetilde{D}_{i,i} = \sum_j \widetilde{A}_{i,j}$
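
As a quick reference, here is a minimal sketch of the trick itself, written out with NumPy from the two definitions above (the function name and code are mine, not the paper's):

import numpy as np

def renormalize(A):
    # renormalization trick: A_hat = D~^{-1/2} (A + I_N) D~^{-1/2}
    A_wave = A + np.eye(A.shape[0])           # A~ = A + I_N  (add self-loops)
    d_wave = A_wave.sum(axis=1)               # D~_ii = sum_j A~_ij
    D_inv_sqrt = np.diag(d_wave ** -0.5)      # D~^{-1/2}
    return D_inv_sqrt @ A_wave @ D_inv_sqrt   # A_hat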

Mathematical derivation of the example

The verification follows the example from the blog post Semi-Supervised Classification With Graph Convolutional Networks.
(Figures 1–3 from that post: the example's step-by-step derivation, ending in the normalized adjacency matrix; images not reproduced here.)

The formula above normalizes the adjacency matrix; this normalization is called random walk normalization.
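
The figures are not reproduced here; for orientation, "random walk normalization" is commonly taken to mean $D^{-1} A$, i.e. dividing every row of $A$ by the node's degree so that each row sums to 1. A small sketch on the example graph used below, under that assumption (variable names are mine):

import numpy as np

# adjacency matrix of the 6-node example graph used in the verification below
A = np.array([[0,1,0,0,1,0],
              [1,0,1,0,1,0],
              [0,1,0,1,0,0],
              [0,0,1,0,1,1],
              [1,1,0,1,0,0],
              [0,0,0,1,0,0]], dtype=float)
deg = A.sum(axis=1)               # node degrees

rw_norm = np.diag(1.0 / deg) @ A  # D^{-1} A: random walk normalization
print(rw_norm.sum(axis=1))        # every row sums to 1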

In practice, however, the dynamic behaviour matters more, so what is usually used is the renormalization (note that the formula in the figure below contains an error: the $\widetilde{A}$ on the left-hand side of the equals sign is written as $\widehat{A}$ in the paper):

(Figures 4–5: the renormalization step of the derivation; images not reproduced here.)


Programmatic verification of the example

$$I_N + D^{-1/2} A D^{-1/2} \;\longrightarrow\; \widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$$

We verify the formula above with a program:

  • $I_N$

    # [[1 0 0 0 0 0]
    #  [0 1 0 0 0 0]
    #  [0 0 1 0 0 0]
    #  [0 0 0 1 0 0]
    #  [0 0 0 0 1 0]
    #  [0 0 0 0 0 1]]
    
  • $D$

    # [[2 0 0 0 0 0]
    #  [0 3 0 0 0 0]
    #  [0 0 2 0 0 0]
    #  [0 0 0 3 0 0]
    #  [0 0 0 0 3 0]
    #  [0 0 0 0 0 1]]
    
  • $A$

    # [[0 1 0 0 1 0]
    #  [1 0 1 0 1 0]
    #  [0 1 0 1 0 0]
    #  [0 0 1 0 1 1]
    #  [1 1 0 1 0 0]
    #  [0 0 0 1 0 0]]
    
  • $L = D - A$

    # [[ 2 -1  0  0 -1  0]
    #  [-1  3 -1  0 -1  0]
    #  [ 0 -1  2 -1  0  0]
    #  [ 0  0 -1  3 -1 -1]
    #  [-1 -1  0 -1  3  0]
    #  [ 0  0  0 -1  0  1]]
    
  • $\widetilde{A} = A + I_N$

    # [[1, 1, 0, 0, 1, 0],
    #  [1, 1, 1, 0, 1, 0],
    #  [0, 1, 1, 1, 0, 0],
    #  [0, 0, 1, 1, 1, 1],
    #  [1, 1, 0, 1, 1, 0],
    #  [0, 0, 0, 1, 0, 1]]
    
  • $\widetilde{D} = D + I_N$

    # [[3, 0, 0, 0, 0, 0],
    #  [0, 4, 0, 0, 0, 0],
    #  [0, 0, 3, 0, 0, 0],
    #  [0, 0, 0, 4, 0, 0],
    #  [0, 0, 0, 0, 4, 0],
    #  [0, 0, 0, 0, 0, 2]]
    
  • $\widetilde{D}^{-1/2}$

    # [[0.57735027 0.         0.         0.         0.         0.        ]
    #  [0.         0.5        0.         0.         0.         0.        ]
    #  [0.         0.         0.57735027 0.         0.         0.        ]
    #  [0.         0.         0.         0.5        0.         0.        ]
    #  [0.         0.         0.         0.         0.5        0.        ]
    #  [0.         0.         0.         0.         0.         0.70710678]]
    
  • $D^{-1/2}$

    # [[0.70710678 0.         0.         0.         0.         0.        ]
    #  [0.         0.57735027 0.         0.         0.         0.        ]
    #  [0.         0.         0.70710678 0.         0.         0.        ]
    #  [0.         0.         0.         0.57735027 0.         0.        ]
    #  [0.         0.         0.         0.         0.57735027 0.        ]
    #  [0.         0.         0.         0.         0.         1.        ]]
    

Carrying out the computation gives the following results:

  • $I_N + D^{-1/2} A D^{-1/2}$

    # [[1.         0.40824829 0.         0.         0.40824829 0.        ]
    #  [0.40824829 1.         0.40824829 0.         0.33333333 0.        ]
    #  [0.         0.40824829 1.         0.40824829 0.         0.        ]
    #  [0.         0.         0.40824829 1.         0.33333333 0.57735027]
    #  [0.40824829 0.33333333 0.         0.33333333 1.         0.        ]
    #  [0.         0.         0.         0.57735027 0.         1.        ]]
    
  • $\widehat{A} = \widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$

    # [[0.33333333 0.28867513 0.         0.         0.28867513 0.        ]
    #  [0.28867513 0.25       0.28867513 0.         0.25       0.        ]
    #  [0.         0.28867513 0.33333333 0.28867513 0.         0.        ]
    #  [0.         0.         0.28867513 0.25       0.25       0.35355339]
    #  [0.28867513 0.25       0.         0.25       0.25       0.        ]
    #  [0.         0.         0.         0.35355339 0.         0.5       ]]
    

Some understanding of the results

The values in $D^{-1/2}$ and $\widetilde{D}^{-1/2}$ do not differ much; the number under the square root in the denominator of $\widetilde{D}^{-1/2}$ is simply larger by 1 than that of $D^{-1/2}$ (because of the added self-loops).

So I believe the two methods produce similar normalized values off the diagonal; the key difference lies in the normalized values on the diagonal.

In the first formulation, $I_N + D^{-1/2} A D^{-1/2}$, the diagonal entries are added only after the adjacency matrix $A$ has been normalized, so the whole diagonal is exactly 1.

In the second formulation, $\widehat{A} = \widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$, the diagonal entries are introduced before the normalization, so the normalization covers them as well, and their values therefore depend on the degree matrix $D$.
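
To make this concrete, here is a small sketch (mine, rebuilding the example matrices from above) that compares the two results on and off the diagonal:

import numpy as np

A = np.array([[0,1,0,0,1,0],[1,0,1,0,1,0],[0,1,0,1,0,0],
              [0,0,1,0,1,1],[1,1,0,1,0,0],[0,0,0,1,0,0]], dtype=float)
I = np.eye(6)
D_inv_sqrt      = np.diag(A.sum(axis=1) ** -0.5)        # D^{-1/2}
D_wave_inv_sqrt = np.diag((A + I).sum(axis=1) ** -0.5)  # D~^{-1/2}

before = I + D_inv_sqrt @ A @ D_inv_sqrt                # I_N + D^{-1/2} A D^{-1/2}
after  = D_wave_inv_sqrt @ (A + I) @ D_wave_inv_sqrt    # D~^{-1/2} A~ D~^{-1/2}

off_diag = ~np.eye(6, dtype=bool)
print(np.abs(before - after)[off_diag].max())  # largest off-diagonal gap between the two
print(np.diag(before))                         # all exactly 1
print(np.diag(after))                          # 1 / (d_i + 1), so it depends on the degree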


What the renormalization trick is for

The authors give the reason for introducing the renormalization trick in the paper:

$I_N + D^{-1/2} A D^{-1/2}$ has eigenvalues in the range [0, 2]. Repeated application of this operator can therefore lead to numerical instabilities and exploding/vanishing gradients when used in a deep neural network model.
To alleviate this problem, we introduce the following renormalization trick: $I_N + D^{-1/2} A D^{-1/2} \rightarrow \widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$, with $\widetilde{A} = A + I_N$ and $\widetilde{D}_{ii} = \sum_j \widetilde{A}_{ij}$.

To paraphrase: the renormalization trick $I_N + D^{-1/2} A D^{-1/2} \rightarrow \widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$ is introduced to avoid the numerical instability and vanishing or exploding gradients that repeated application of the operator $I_N + D^{-1/2} A D^{-1/2}$ would cause in a deep network.

This is because the eigenvalues of $I_N + D^{-1/2} A D^{-1/2}$ lie in the range [0, 2], whereas the eigenvalues of $\widetilde{D}^{-1/2} \widetilde{A} \widetilde{D}^{-1/2}$ lie in [-1, 1].

(My own guess, not verified: $I_N + D^{-1/2} A D^{-1/2}$ can be viewed as the superposition of $I_N$ and $D^{-1/2} A D^{-1/2}$.)
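
The eigenvalue claim is easy to check numerically on the example graph. A quick sketch (mine; both operators are symmetric, so np.linalg.eigvalsh applies, and raising them to a power roughly mimics stacking many layers with identity weights and no nonlinearity):

import numpy as np

A = np.array([[0,1,0,0,1,0],[1,0,1,0,1,0],[0,1,0,1,0,0],
              [0,0,1,0,1,1],[1,1,0,1,0,0],[0,0,0,1,0,0]], dtype=float)
I = np.eye(6)
before = I + np.diag(A.sum(1) ** -0.5) @ A @ np.diag(A.sum(1) ** -0.5)
A_wave = A + I
after  = np.diag(A_wave.sum(1) ** -0.5) @ A_wave @ np.diag(A_wave.sum(1) ** -0.5)

print(np.linalg.eigvalsh(before))  # eigenvalues within [0, 2]; the largest is 2
print(np.linalg.eigvalsh(after))   # eigenvalues within [-1, 1]; the largest is 1

# repeated application: powers of the first operator blow up, the renormalized one stays bounded
for k in (1, 5, 10, 20):
    print(k,
          np.linalg.norm(np.linalg.matrix_power(before, k)),  # grows roughly like 2**k
          np.linalg.norm(np.linalg.matrix_power(after, k)))   # stays bounded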

In other words, the renormalization trick is not an identity transformation to begin with, which the verification above already shows. But as long as the normalization is achieved, small differences like these hardly matter; normalization is, at the end of the day, just a way of rescaling the data.

What it does buy, on the other hand, is protection against numerical instability and vanishing or exploding gradients, and that is very useful.


Appendix: program source code

import numpy as np

I=np.array([[1,0,0,0,0,0],[0,1,0,0,0,0],[0,0,1,0,0,0],[0,0,0,1,0,0],[0,0,0,0,1,0],[0,0,0,0,0,1]])  # identity matrix I_N
# [[1 0 0 0 0 0]
#  [0 1 0 0 0 0]
#  [0 0 1 0 0 0]
#  [0 0 0 1 0 0]
#  [0 0 0 0 1 0]
#  [0 0 0 0 0 1]]
D=np.array([[2,0,0,0,0,0],[0,3,0,0,0,0],[0,0,2,0,0,0],[0,0,0,3,0,0],[0,0,0,0,3,0],[0,0,0,0,0,1]])  # degree matrix of the example graph
# [[2 0 0 0 0 0]
#  [0 3 0 0 0 0]
#  [0 0 2 0 0 0]
#  [0 0 0 3 0 0]
#  [0 0 0 0 3 0]
#  [0 0 0 0 0 1]]
A=np.array([[0,1,0,0,1,0],[1,0,1,0,1,0],[0,1,0,1,0,0],[0,0,1,0,1,1],[1,1,0,1,0,0],[0,0,0,1,0,0]])  # adjacency matrix of the example graph
# [[0 1 0 0 1 0]
#  [1 0 1 0 1 0]
#  [0 1 0 1 0 0]
#  [0 0 1 0 1 1]
#  [1 1 0 1 0 0]
#  [0 0 0 1 0 0]]
L=D-A  # graph Laplacian (shown for reference; not used below)
# [[ 2 -1  0  0 -1  0]
#  [-1  3 -1  0 -1  0]
#  [ 0 -1  2 -1  0  0]
#  [ 0  0 -1  3 -1 -1]
#  [-1 -1  0 -1  3  0]
#  [ 0  0  0 -1  0  1]]
A_wave=A+I
# [[1, 1, 0, 0, 1, 0],
#  [1, 1, 1, 0, 1, 0],
#  [0, 1, 1, 1, 0, 0],
#  [0, 0, 1, 1, 1, 1],
#  [1, 1, 0, 1, 1, 0],
#  [0, 0, 0, 1, 0, 1]]
D_wave=D+I
# [[3, 0, 0, 0, 0, 0],
#  [0, 4, 0, 0, 0, 0],
#  [0, 0, 3, 0, 0, 0],
#  [0, 0, 0, 4, 0, 0],
#  [0, 0, 0, 0, 4, 0],
#  [0, 0, 0, 0, 0, 2]]

D_wave_sqrt=D_wave**-0.5      # element-wise power: the off-diagonal zeros become inf
D_wave_sqrt[D_wave_sqrt>1]=0  # reset those infs to 0, leaving D~^{-1/2}
# [[0.57735027 0.         0.         0.         0.         0.        ]
#  [0.         0.5        0.         0.         0.         0.        ]
#  [0.         0.         0.57735027 0.         0.         0.        ]
#  [0.         0.         0.         0.5        0.         0.        ]
#  [0.         0.         0.         0.         0.5        0.        ]
#  [0.         0.         0.         0.         0.         0.70710678]]

D_wave_sqrt=np.mat(D_wave_sqrt)  # as np.mat, '*' below means matrix multiplication
A_wave=np.mat(A_wave)
A_hat=D_wave_sqrt*A_wave*D_wave_sqrt  # D~^{-1/2} A~ D~^{-1/2}
# [[0.33333333 0.28867513 0.         0.         0.28867513 0.        ]
#  [0.28867513 0.25       0.28867513 0.         0.25       0.        ]
#  [0.         0.28867513 0.33333333 0.28867513 0.         0.        ]
#  [0.         0.         0.28867513 0.25       0.25       0.35355339]
#  [0.28867513 0.25       0.         0.25       0.25       0.        ]
#  [0.         0.         0.         0.35355339 0.         0.5       ]]

D_sqrt=D**-0.5      # same element-wise trick for D^{-1/2}
D_sqrt[D_sqrt>1]=0
# [[0.70710678 0.         0.         0.         0.         0.        ]
#  [0.         0.57735027 0.         0.         0.         0.        ]
#  [0.         0.         0.70710678 0.         0.         0.        ]
#  [0.         0.         0.         0.57735027 0.         0.        ]
#  [0.         0.         0.         0.         0.57735027 0.        ]
#  [0.         0.         0.         0.         0.         1.        ]]

# I + D^{-1/2} A D^{-1/2}
D_sqrt=np.mat(D_sqrt)
A=np.mat(A)
I=np.mat(I)
before = I + D_sqrt*A*D_sqrt
print(before)
# [[1.         0.40824829 0.         0.         0.40824829 0.        ]
#  [0.40824829 1.         0.40824829 0.         0.33333333 0.        ]
#  [0.         0.40824829 1.         0.40824829 0.         0.        ]
#  [0.         0.         0.40824829 1.         0.33333333 0.57735027]
#  [0.40824829 0.33333333 0.         0.33333333 1.         0.        ]
#  [0.         0.         0.         0.57735027 0.         1.        ]]

after=D_wave_sqrt*A_wave*D_wave_sqrt
print(after)
# [[0.33333333 0.28867513 0.         0.         0.28867513 0.        ]
#  [0.28867513 0.25       0.28867513 0.         0.25       0.        ]
#  [0.         0.28867513 0.33333333 0.28867513 0.         0.        ]
#  [0.         0.         0.28867513 0.25       0.25       0.35355339]
#  [0.28867513 0.25       0.         0.25       0.25       0.        ]
#  [0.         0.         0.         0.35355339 0.         0.5       ]]
