In this section we implement self-attention pooling (Self-Attention Pooling) in code. The idea of this method is to learn node importance adaptively from the graph itself through graph convolution.[0] Specifically, using the graph convolution defined in Chapter 1, each node can be assigned an importance score, as shown below:

$$Z = \sigma\left(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}X\Theta_{att}\right)$$

where $\sigma$ is the activation function, $\tilde{A}$ is the adjacency matrix with self-loops added, $X$ denotes the node features, and $\Theta_{att}\in\mathbb{R}^{F\times 1}$ (with $F$ the feature dimension) is the weight matrix, the only parameter introduced by the self-attention pooling layer. For the implementation of this graph convolution, refer to the code in Section 1.6; it is not repeated here.
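To make the formula concrete, here is a minimal dense sketch (a toy adjacency matrix with random features and weights, chosen purely for illustration; the actual implementation uses the sparse graph convolution from Section 1.6, and tanh is used here because it is the default activation of the pooling layer defined later):

import torch

# Toy graph: 4 nodes on a path 0-1-2-3
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])
X = torch.randn(4, 8)          # node features, F = 8
theta_att = torch.randn(8, 1)  # the single weight vector of the pooling layer

A_tilde = A + torch.eye(4)                             # add self-loops
D_inv_sqrt = torch.diag(A_tilde.sum(dim=1).pow(-0.5))  # D^{-1/2} of A_tilde
Z = torch.tanh(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ X @ theta_att)
print(Z.shape)  # torch.Size([4, 1]): one importance score per node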
Based on the node importance scores and the graph topology, the pooling operation can then be performed as follows:

$$\boldsymbol{i} = \text{top-rank}(\boldsymbol{z}, kN)$$

where $N$ is the number of nodes and $k$ is the ratio of nodes to keep. The less important nodes are discarded, and the adjacency matrix and node features are updated accordingly to obtain the pooled result. Let us first look at how to implement the node selection defined by the equation above. The code is shown in Listing 4-1:
import torch


def top_rank(attention_score, graph_indicator, keep_ratio):
    """Perform pooling for each graph based on the given attention_score.

    To make the pooling process easy to follow, each graph is pooled
    separately and the results are then concatenated for the next step.

    Arguments:
    ----------
        attention_score: torch.Tensor
            Attention scores computed with a GCN, Z = GCN(A, X)
        graph_indicator: torch.Tensor
            Indicates which graph each node belongs to
        keep_ratio: float
            Ratio of nodes to keep; the number of kept nodes is int(N * keep_ratio)
    """
    # TODO: confirm that graph_indicator is sorted; it must be in ascending
    # order, with nodes of the same graph stored contiguously
    graph_id_list = list(set(graph_indicator.cpu().numpy()))
    mask = attention_score.new_empty((0,), dtype=torch.bool)
    for graph_id in graph_id_list:
        graph_attn_score = attention_score[graph_indicator == graph_id]
        graph_node_num = len(graph_attn_score)
        graph_mask = attention_score.new_zeros((graph_node_num,),
                                               dtype=torch.bool)
        keep_graph_node_num = int(keep_ratio * graph_node_num)
        _, sorted_index = graph_attn_score.sort(descending=True)
        graph_mask[sorted_index[:keep_graph_node_num]] = True
        mask = torch.cat((mask, graph_mask))
    return mask
The function top_rank takes three arguments. The first is attention_score, the node importance scores computed with a GCN. The second is graph_indicator, which indicates which graph each node belongs to; here, multiple graphs to be classified are put together and processed as a batch to speed up computation, so graph_indicator contains data of the form $[0,0,\dots,0,1,1,\dots,1,2,2,\dots,2,\dots]$. Note that the graph identifiers in graph_indicator must be in ascending order, and the nodes belonging to the same graph must be stored contiguously. The third argument is the hyperparameter keep_ratio, the ratio of nodes to keep at each pooling step; this ratio applies to each individual graph, not to the batch as a whole. The implementation iterates over the graphs according to graph_indicator, takes out the attention scores of each graph, sorts them to obtain the indices of the nodes to keep, and sets those positions to True, yielding a mask vector for the nodes of each graph. The masks of all graphs are then concatenated to form the mask over all nodes in the batch, which is the return value of the function.
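As a quick sanity check of top_rank, here is a toy batch with scores chosen by hand so the expected output is obvious:

import torch

# Two graphs of 3 nodes each
attention_score = torch.tensor([0.1, 0.9, 0.5, 0.3, 0.8, 0.2])
graph_indicator = torch.tensor([0, 0, 0, 1, 1, 1])

mask = top_rank(attention_score, graph_indicator, keep_ratio=0.5)
print(mask)
# tensor([False,  True, False, False,  True, False])
# int(0.5 * 3) = 1 node kept per graph: the highest-scoring node of each graph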
Next, the graph structure and node features are updated according to the node mask. The graph structure is updated by indexing the adjacency matrix with the mask vector to obtain the adjacency matrix among the kept nodes, which is then normalized and used as input to the subsequent GCN layers. For this we define two helper functions, normalization(adjacency) and filter_adjacency(adjacency, mask). The function normalization(adjacency) takes a scipy.sparse.csr_matrix, normalizes it, and converts it to a torch.sparse.FloatTensor. The other function, filter_adjacency(adjacency, mask), takes two arguments: the adjacency matrix adjacency before pooling, of type torch.sparse.FloatTensor, and the node mask mask produced by top_rank. To take advantage of the index slicing provided by scipy.sparse, adjacency is first converted back to a scipy.sparse.csr_matrix and sliced with mask to obtain the adjacency relations among the pooled nodes, which are then normalized with normalization and used as input to the next graph convolution layer. See Listing 4-2:
import numpy as np
import scipy.sparse as sp
import torch


def normalization(adjacency):
    """Compute L = D^-0.5 * (A + I) * D^-0.5.

    Args:
        adjacency: sp.csr_matrix
    Returns:
        Normalized adjacency matrix as a torch.sparse.FloatTensor
    """
    adjacency += sp.eye(adjacency.shape[0])  # add self-loops
    degree = np.array(adjacency.sum(1))
    d_hat = sp.diags(np.power(degree, -0.5).flatten())
    L = d_hat.dot(adjacency).dot(d_hat).tocoo()
    # convert to torch.sparse.FloatTensor
    indices = torch.from_numpy(np.asarray([L.row, L.col])).long()
    values = torch.from_numpy(L.data.astype(np.float32))
    tensor_adjacency = torch.sparse.FloatTensor(indices, values, L.shape)
    return tensor_adjacency


def filter_adjacency(adjacency, mask):
    """Update the graph structure according to the mask.

    Args:
        adjacency: torch.sparse.FloatTensor, adjacency matrix before pooling
        mask: torch.Tensor(dtype=torch.bool), node mask vector
    Returns:
        torch.sparse.FloatTensor, normalized adjacency matrix after pooling
    """
    device = adjacency.device
    mask = mask.cpu().numpy()
    indices = adjacency.coalesce().indices().cpu().numpy()
    num_nodes = adjacency.size(0)
    row, col = indices
    maskout_self_loop = row != col  # drop self-loops; normalization re-adds them
    row = row[maskout_self_loop]
    col = col[maskout_self_loop]
    sparse_adjacency = sp.csr_matrix((np.ones(len(row)), (row, col)),
                                     shape=(num_nodes, num_nodes), dtype=np.float32)
    filtered_adjacency = sparse_adjacency[mask, :][:, mask]
    return normalization(filtered_adjacency).to(device)
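A small usage sketch of the two helper functions, on a hypothetical 4-node path graph from which node 2 is dropped:

import numpy as np
import scipy.sparse as sp
import torch

# Build the path graph 0-1-2-3 as a scipy csr_matrix and normalize it
edges = np.array([[0, 1, 1, 2, 2, 3],
                  [1, 0, 2, 1, 3, 2]])
adj = sp.csr_matrix((np.ones(edges.shape[1], dtype=np.float32),
                     (edges[0], edges[1])), shape=(4, 4))
sparse_adj = normalization(adj)            # torch.sparse.FloatTensor, 4 x 4

# Keep nodes 0, 1 and 3; filter_adjacency re-normalizes the reduced graph
mask = torch.tensor([True, True, False, True])
pooled_adj = filter_adjacency(sparse_adj, mask)
print(pooled_adj.shape)  # torch.Size([3, 3])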
With the helper functions introduced above, we can now implement the self-attention pooling layer. The layer outputs the pooled features, the indicator of which graph each remaining node belongs to, and the normalized adjacency matrix. See Listing 4-3:
import torch
import torch.nn as nn

# GraphConvolution is the graph convolution layer from Section 1.6


class SelfAttentionPooling(nn.Module):
    def __init__(self, input_dim, keep_ratio, activation=torch.tanh):
        super(SelfAttentionPooling, self).__init__()
        self.input_dim = input_dim
        self.keep_ratio = keep_ratio
        self.activation = activation
        # GCN with a single output channel: one attention score per node
        self.attn_gcn = GraphConvolution(input_dim, 1)

    def forward(self, adjacency, input_feature, graph_indicator):
        attn_score = self.attn_gcn(adjacency, input_feature).squeeze()
        attn_score = self.activation(attn_score)
        # select the top keep_ratio nodes of each graph
        mask = top_rank(attn_score, graph_indicator, self.keep_ratio)
        # weight the kept features by their attention scores
        hidden = input_feature[mask] * attn_score[mask].view(-1, 1)
        mask_graph_indicator = graph_indicator[mask]
        mask_adjacency = filter_adjacency(adjacency, mask)
        return hidden, mask_graph_indicator, mask_adjacency
For graph classification we also need a global pooling operation, which reduces graphs with different numbers of nodes to representations of the same dimension. Common global pooling methods include taking the maximum or the mean. The implementations of these two methods are shown in Listing 4-4:
import torch_scatter


def global_max_pool(x, graph_indicator):
    num = graph_indicator.max().item() + 1
    return torch_scatter.scatter_max(x, graph_indicator, dim=0, dim_size=num)[0]


def global_avg_pool(x, graph_indicator):
    num = graph_indicator.max().item() + 1
    return torch_scatter.scatter_mean(x, graph_indicator, dim=0, dim_size=num)
Here we use the torch_scatter package to simplify the implementation; the principles of the two functions used, scatter_mean and scatter_max, are illustrated in Figure 4-6.
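As a small illustration of how graph_indicator groups node rows into one vector per graph (assuming torch_scatter is installed):

import torch

x = torch.tensor([[1., 2.],
                  [3., 4.],
                  [5., 6.]])
graph_indicator = torch.tensor([0, 0, 1])  # nodes 0 and 1 belong to graph 0, node 2 to graph 1

print(global_avg_pool(x, graph_indicator))
# tensor([[2., 3.],
#         [5., 6.]])
print(global_max_pool(x, graph_indicator))
# tensor([[3., 4.],
#         [5., 6.]])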
We can now define the graph classification models. Next we define the two SAGPool variants shown in Figure 4-7. The model in Figure (a) uses only a single pooling layer and is called SAGPool_g, where "g" stands for global; its implementation is given in Listing 4-5. The model in Figure (b) uses multiple pooling layers and is called SAGPool_h, where "h" stands for hierarchical; its implementation is given in Listing 4-6. The experiments in the paper show that SAGPool_g is better suited to classifying small graphs, while SAGPool_h works better on large graphs.
import torch.nn.functional as F


class ModelA(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes=2):
        """Graph classification model, architecture A (SAGPool_g).

        Args:
        ----
            input_dim: int, dimension of the input features
            hidden_dim: int, number of hidden units
            num_classes: int, number of classes (default: 2)
        """
        super(ModelA, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
        self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool = SelfAttentionPooling(hidden_dim * 3, 0.5)
        self.fc1 = nn.Linear(hidden_dim * 3 * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc3 = nn.Linear(hidden_dim // 2, num_classes)

    def forward(self, adjacency, input_feature, graph_indicator):
        gcn1 = F.relu(self.gcn1(adjacency, input_feature))
        gcn2 = F.relu(self.gcn2(adjacency, gcn1))
        gcn3 = F.relu(self.gcn3(adjacency, gcn2))
        # concatenate the outputs of the three GCN layers
        gcn_feature = torch.cat((gcn1, gcn2, gcn3), dim=1)
        pool, pool_graph_indicator, pool_adjacency = self.pool(adjacency, gcn_feature,
                                                               graph_indicator)
        # readout: concatenation of global average pooling and global max pooling
        readout = torch.cat((global_avg_pool(pool, pool_graph_indicator),
                             global_max_pool(pool, pool_graph_indicator)), dim=1)
        fc1 = F.relu(self.fc1(readout))
        fc2 = F.relu(self.fc2(fc1))
        logits = self.fc3(fc2)
        return logits
The SAGPool_h model is implemented as shown in Listing 4-6.
class ModelB(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes=2):
        """Graph classification model, architecture B (SAGPool_h).

        Args:
        -----
            input_dim: int, dimension of the input features
            hidden_dim: int, number of hidden units
            num_classes: int, number of classes (default: 2)
        """
        super(ModelB, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.pool1 = SelfAttentionPooling(hidden_dim, 0.5)
        self.gcn2 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool2 = SelfAttentionPooling(hidden_dim, 0.5)
        self.gcn3 = GraphConvolution(hidden_dim, hidden_dim)
        self.pool3 = SelfAttentionPooling(hidden_dim, 0.5)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, num_classes))

    def forward(self, adjacency, input_feature, graph_indicator):
        gcn1 = F.relu(self.gcn1(adjacency, input_feature))
        pool1, pool1_graph_indicator, pool1_adjacency = \
            self.pool1(adjacency, gcn1, graph_indicator)
        global_pool1 = torch.cat(
            [global_avg_pool(pool1, pool1_graph_indicator),
             global_max_pool(pool1, pool1_graph_indicator)],
            dim=1)

        gcn2 = F.relu(self.gcn2(pool1_adjacency, pool1))
        pool2, pool2_graph_indicator, pool2_adjacency = \
            self.pool2(pool1_adjacency, gcn2, pool1_graph_indicator)
        global_pool2 = torch.cat(
            [global_avg_pool(pool2, pool2_graph_indicator),
             global_max_pool(pool2, pool2_graph_indicator)],
            dim=1)

        gcn3 = F.relu(self.gcn3(pool2_adjacency, pool2))
        pool3, pool3_graph_indicator, pool3_adjacency = \
            self.pool3(pool2_adjacency, gcn3, pool2_graph_indicator)
        global_pool3 = torch.cat(
            [global_avg_pool(pool3, pool3_graph_indicator),
             global_max_pool(pool3, pool3_graph_indicator)],
            dim=1)

        # sum of the readouts from the three hierarchical levels
        readout = global_pool1 + global_pool2 + global_pool3
        logits = self.mlp(readout)
        return logits
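To tie the pieces together, here is a minimal forward-pass sketch. It assumes the GraphConvolution layer from Section 1.6 has been imported (as used above, it takes the normalized sparse adjacency matrix and a feature matrix); the two toy graphs and all dimensions are chosen purely for illustration:

import numpy as np
import scipy.sparse as sp
import torch

# Batch two toy graphs into one block-diagonal adjacency matrix:
# graph 0 is a triangle (3 nodes), graph 1 is a 4-node path
triangle = sp.csr_matrix(np.array([[0, 1, 1],
                                   [1, 0, 1],
                                   [1, 1, 0]], dtype=np.float32))
path = sp.csr_matrix(np.array([[0, 1, 0, 0],
                               [1, 0, 1, 0],
                               [0, 1, 0, 1],
                               [0, 0, 1, 0]], dtype=np.float32))
adjacency = normalization(sp.block_diag([triangle, path]).tocsr())

input_feature = torch.randn(7, 16)                     # 7 nodes with 16-dim features
graph_indicator = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # sorted, contiguous per graph

model = ModelA(input_dim=16, hidden_dim=32, num_classes=2)
logits = model(adjacency, input_feature, graph_indicator)
print(logits.shape)  # torch.Size([2, 2]): one row of class logits per graph

The same inputs can be fed to ModelB; during training, the logits would go into a cross-entropy loss together with the per-graph labels.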
References
[0] 刘忠雨, 李彦霖, 周洋. 《深入浅出图神经网络:GNN原理解析》. 机械工业出版社.