Notes: PyTorch Geometric: A Super-Detailed Walkthrough of the GAT Code | source node | target node | source_to_target

Contents

  • Notes: PyTorch Geometric: A Super-Detailed Walkthrough of the GAT Code | source node | target node | source_to_target
    • Preface
    • The official GAT implementation in torch geometric
    • Source code walkthrough
      • The input graph
      • The __init__ part
      • The forward part
      • Overriding the message method
    • Summary

Notes: PyTorch Geometric: A Super-Detailed Walkthrough of the GAT Code | source node | target node | source_to_target

Sharing this for the community, a like would be appreciated QAQ. My abilities are limited, so if you spot mistakes, corrections are very welcome.

I did not want to read the source code but still wanted to understand how the torch-geometric library implements GAT on top of message passing. None of the blog posts I could find were satisfying, and the official documentation alone did not fully get me there (probably a limitation of my own understanding), hence this source-code walkthrough.

Preface

  • What is GAT? It stands for Graph Attention Networks; for the method itself, please refer to other articles.
    [Figure 1: an illustration of GAT]
  • What is PyTorch Geometric? It is a widely used library for implementing graph neural network methods. This post details its GAT implementation; see the official documentation: torch-geometric GAT.
  • What is message passing? It is a class that torch geometric provides to make building graph neural networks easier; the GAT implementation inherits from the MessagePassing class (a minimal sketch follows this list).
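To make the mechanism concrete, here is a minimal sketch of a MessagePassing subclass (a hypothetical toy layer, not the official GAT): propagate() gathers the per-edge arguments, calls message() on them, and aggregates the resulting messages at the target nodes.

import torch
from torch_geometric.nn import MessagePassing

class MeanNeighborConv(MessagePassing):
    """Toy layer: each node receives the mean of its neighbors' features."""
    def __init__(self):
        super().__init__(aggr='mean', flow='source_to_target')

    def forward(self, x, edge_index):
        # x: [num_nodes, num_features], edge_index: [2, num_edges]
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # x_j: the source-node features of every edge, shape [num_edges, num_features]
        return x_j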

The official GAT implementation in torch geometric

  • The formula given in the official documentation:

    $$\mathbf{x}^{\prime}_i = \alpha_{i,i}\,\Theta\mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\,\Theta\mathbf{x}_j$$

    with attention coefficients

    $$\alpha_{i,j} = \frac{\exp\big(\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[\Theta\mathbf{x}_i \,\Vert\, \Theta\mathbf{x}_j]\big)\big)}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} \exp\big(\mathrm{LeakyReLU}\big(\mathbf{a}^{\top}[\Theta\mathbf{x}_i \,\Vert\, \Theta\mathbf{x}_k]\big)\big)}$$

    Here $\Theta$ is the learnable weight matrix and $\alpha_{ij}$ are the attention coefficients; the docs note that i denotes the target node and j the source node. From the formula (or from any GAT diagram) it is easy to see that messages flow from source nodes to target nodes.
  • The official GATConv source code:
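For reference, the imports this excerpt relies on (roughly as they appear in the PyG source of that version; the relative import of glorot and zeros is written out as an absolute one here):

from typing import Union, Tuple, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import Adj, OptTensor, OptPairTensor, Size
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax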
class GATConv(MessagePassing):
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        if isinstance(in_channels, int):
            self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
            self.lin_r = self.lin_l
        else:
            self.lin_l = Linear(in_channels[0], heads * out_channels, False)
            self.lin_r = Linear(in_channels[1], heads * out_channels, False)

        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights=None):
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
        alpha_l: OptTensor = None
        alpha_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = x_r = self.lin_l(x).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
        else:
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = self.lin_l(x_l).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)
                alpha_r = (x_r * self.att_r).sum(dim=-1)

        assert x_l is not None
        assert alpha_l is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)

Source code walkthrough

The input graph

To make the source easier to follow, we build a simple input graph: three nodes labeled 0, 1 and 2, each with a 2-dimensional feature vector.
[Figure: the example input graph with nodes 0, 1 and 2]
The code to build the graph is as follows:

import torch
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

x = torch.tensor([[1., 2], [2, 3], [1, 3]])
edge_index = torch.LongTensor([[0, 0], [1, 2]])
edge_index = to_undirected(edge_index)  # make the graph undirected
graph = Data(x=x, edge_index=edge_index)
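With the graph in hand, the layer can be applied directly; the weights are initialized randomly, so the actual outputs will differ from the hand-computed numbers used below:

conv = GATConv(in_channels=2, out_channels=1, heads=1)
out = conv(graph.x, graph.edge_index)
print(out.shape)   # torch.Size([3, 1]): one output feature per node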

The __init__ part

self.in_channels = in_channels        # input feature dimension, 2 in our example
self.out_channels = out_channels      # output feature dimension, 1 here for simplicity
self.heads = heads                    # number of attention heads, 1 here for simplicity
self.concat = concat                  # whether the per-head outputs are concatenated; irrelevant here since heads=1
self.negative_slope = negative_slope  # slope of the negative half of the LeakyReLU activation
self.dropout = dropout                # dropout against overfitting, not our concern here
self.add_self_loops = add_self_loops  # GAT adds self-loops, i.e. every node is connected to itself
if isinstance(in_channels, int):
    self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
    self.lin_r = self.lin_l
    # lin_l / lin_r correspond to Theta in the formula; this branch and the
    # else branch exist because the implementation allows separate Theta_l and
    # Theta_r for source nodes and target nodes
else:
    self.lin_l = Linear(in_channels[0], heads * out_channels, False)
    self.lin_r = Linear(in_channels[1], heads * out_channels, False)
self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
self.att_r = Parameter(torch.Tensor(1, heads, out_channels))
# att_l and att_r correspond to a^T in the formula; again l and r are used for
# source nodes and target nodes respectively

This part is straightforward, see the comments. Note that MessagePassing has an optional flow argument, which can be 'source_to_target' or 'target_to_source'. GAT clearly uses the former; it is also the default, so nothing needs to be changed.
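Since GATConv forwards **kwargs to MessagePassing, the flow setting can be inspected (or, in principle, overridden) like this; it is not needed for GAT and is shown only to illustrate the knob:

conv = GATConv(2, 1, heads=1, flow='source_to_target')
print(conv.flow)   # 'source_to_target' (the default)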

The forward part

def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            size: Size = None, return_attention_weights=None):

    H, C = self.heads, self.out_channels

The inputs are the node feature matrix x and the edge_index. In our example they are:

x = [
[1, 2],   <---- features of node 0
[2, 3],   <---- features of node 1
[1, 3]    <---- features of node 2
]
edge_index = [
[0, 0, 1, 2],   <---- source nodes
[1, 2, 0, 0]    <---- target nodes
]

Note that edge_index differs from the one we originally constructed; this is because of the to_undirected call when building the graph, which adds the reverse direction of every edge.

  x_l: OptTensor = None
  x_r: OptTensor = None
  alpha_l: OptTensor = None
  alpha_r: OptTensor = None
  if isinstance(x, Tensor):
      assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
      x_l = x_r = self.lin_l(x).view(-1, H, C)
      alpha_l = (x_l * self.att_l).sum(dim=-1)
      alpha_r = (x_r * self.att_r).sum(dim=-1)
  else:
      x_l, x_r = x[0], x[1]
      assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
      x_l = self.lin_l(x_l).view(-1, H, C)
      alpha_l = (x_l * self.att_l).sum(dim=-1)
      if x_r is not None:
          x_r = self.lin_r(x_r).view(-1, H, C)
          alpha_r = (x_r * self.att_r).sum(dim=-1)

x_l and x_r are the feature vectors after left-multiplication by $\Theta$. To stress it again (because it matters a lot later): l corresponds to source nodes, r to target nodes, while i denotes a target node and j a source node.
alpha_l is the elementwise product of x_l and self.att_l summed over the last dimension, which corresponds to $a^T_l \Theta_l x$; alpha_r is computed analogously.
Suppose, for the sake of the example, that the 2-to-1-dimensional mapping is a simple sum (i.e. left-multiplying by $\Theta$ just adds the two input features) and that $a^T$ acts as multiplication by 0.5. Then x_l, x_r, alpha_l and alpha_r are:

x_l = x_r = [
[3],   <---- node 0
[5],   <---- node 1
[4]    <---- node 2
]
alpha_l = alpha_r = [
[1.5], <---- node 0
[2.5], <---- node 1
[2]    <---- node 2
]
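These numbers can be reproduced in code by pinning the parameters accordingly (a sketch under the stated assumptions, using the attribute names lin_l / att_l of the GATConv version quoted above):

conv = GATConv(in_channels=2, out_channels=1, heads=1)
with torch.no_grad():
    conv.lin_l.weight.fill_(1.)   # Theta: sums the two input features (lin_r is the same object)
    conv.att_l.fill_(0.5)         # a_l^T: multiply by 0.5
    conv.att_r.fill_(0.5)         # a_r^T: multiply by 0.5
    x_l = conv.lin_l(graph.x).view(-1, 1, 1)
    alpha_l = (x_l * conv.att_l).sum(dim=-1)
print(x_l.squeeze())      # tensor([3., 5., 4.])
print(alpha_l.squeeze())  # tensor([1.5000, 2.5000, 2.0000])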

if self.add_self_loops:
    if isinstance(edge_index, Tensor):
        num_nodes = x_l.size(0)
        if x_r is not None:
            num_nodes = min(num_nodes, x_r.size(0))
        if size is not None:
            num_nodes = min(size[0], size[1])
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
    elif isinstance(edge_index, SparseTensor):
        edge_index = set_diag(edge_index)

Next, self-loops are added to edge_index. After adding them it becomes:

edge_index = [
[0, 0, 1, 2, 0, 1, 2],   <---- source nodes
[1, 2, 0, 0, 0, 1, 2]    <---- target nodes
]
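The same result can be obtained directly with the utility functions used in forward:

from torch_geometric.utils import add_self_loops, remove_self_loops

ei, _ = remove_self_loops(graph.edge_index)   # a no-op here, there are no self-loops yet
ei, _ = add_self_loops(ei, num_nodes=3)
# ei: tensor([[0, 0, 1, 2, 0, 1, 2],
#             [1, 2, 0, 0, 0, 1, 2]])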

 out = self.propagate(edge_index, x=(x_l, x_r),
                      alpha=(alpha_l, alpha_r), size=size)

This calls the propagate method of MessagePassing, an orchestrating method that in turn calls message, aggregate and update. With flow set to source_to_target, message produces the information each source node sends out, aggregate collects that information at each target node (typically max, add (the default), etc.; GAT uses add), and update refreshes the node representations (it is not overridden here, so it is the identity). The key to implementing GAT is therefore how the message method is written.
Note that the arguments passed to propagate are forwarded to message and aggregate. Here x is a pair such as (x_l, x_r): the first element is used for the source-node side and the second for the target-node side. The indexing that propagate performs is sketched below.
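Conceptually (a simplification of what propagate does internally, not the library's actual code), the "_j" arguments handed to message are gathered along the source row of edge_index and the "_i" arguments along the target row:

src, dst = edge_index[0], edge_index[1]
x_j     = x_l[src]        # per-edge features of the source node
alpha_j = alpha_l[src]    # per-edge attention term of the source node
alpha_i = alpha_r[dst]    # per-edge attention term of the target node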

Overriding the message method

def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, index, ptr, size_i)
    self._alpha = alpha
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)
    return x_j * alpha.unsqueeze(-1)

x_j and alpha_j carry the source-node information of every edge, index is the target-node label each edge points to, and ptr defaults to None, so we ignore it here. Does that sound confusing? This is where a numeric example helps.
At this point we have:

edge_index = [
[0, 0, 1, 2, 0, 1, 2],   <---- source nodes
[1, 2, 0, 0, 0, 1, 2]    <---- target nodes
]
x_l = x_r = [
[3],   <---- node 0
[5],   <---- node 1
[4]    <---- node 2
]
alpha_l = alpha_r = [
[1.5], <---- node 0
[2.5], <---- node 1
[2]    <---- node 2
]

The variables passed into message are:

index = [1, 2, 0, 0, 0, 1, 2]   <---- the second row of edge_index, i.e. the target nodes
x_j = [
[3],   <---- x_l value of node 0
[3],   <---- x_l value of node 0
[5],   <---- x_l value of node 1
[4],   <---- x_l value of node 2
[3],   <---- x_l value of node 0
[5],   <---- x_l value of node 1
[4]    <---- x_l value of node 2
]   <---- ordered like the first row of edge_index
alpha_j = [
[1.5], <---- alpha_l value of node 0
[1.5], <---- alpha_l value of node 0
[2.5], <---- alpha_l value of node 1
[2],   <---- alpha_l value of node 2
[1.5], <---- alpha_l value of node 0
[2.5], <---- alpha_l value of node 1
[2]    <---- alpha_l value of node 2
]   <---- ordered like the first row of edge_index
alpha_i = [
[2.5], <---- alpha_r value of node 1
[2],   <---- alpha_r value of node 2
[1.5], <---- alpha_r value of node 0
[1.5], <---- alpha_r value of node 0
[1.5], <---- alpha_r value of node 0
[2.5], <---- alpha_r value of node 1
[2]    <---- alpha_r value of node 2
]   <---- ordered like the second row of edge_index

Now things should be much clearer. All that remains is to explain how the softmax is implemented.

alpha = softmax(alpha, index, ptr, size_i)

Here alpha is the sum of alpha_j and alpha_i, after the LeakyReLU (which leaves these positive values unchanged):

alpha = [[4], [3.5], [4], [3.5], [3], [5], [4]]

The softmax function first exponentiates every entry of alpha, giving exp_alpha:

exp_alpha = [exp(4), exp(3.5), exp(4), exp(3.5), exp(3), exp(5), exp(4)]  # the inner brackets of each entry are omitted for brevity
index = [1, 2, 0, 0, 0, 1, 2]

The final output out of softmax is then computed from exp_alpha and index together:

out = [
exp(4)   / (exp(4) + exp(5)),              <---- edge 0 -> 1, normalized over target node 1
exp(3.5) / (exp(4) + exp(3.5)),            <---- edge 0 -> 2, normalized over target node 2
exp(4)   / (exp(3) + exp(4) + exp(3.5)),   <---- edge 1 -> 0, normalized over target node 0
exp(3.5) / (exp(3) + exp(4) + exp(3.5)),   <---- edge 2 -> 0, normalized over target node 0
exp(3)   / (exp(3) + exp(4) + exp(3.5)),   <---- self-loop of node 0
exp(5)   / (exp(4) + exp(5)),              <---- self-loop of node 1
exp(4)   / (exp(4) + exp(3.5))             <---- self-loop of node 2
]

In words: softmax groups the entries by their value in index (i.e. by target node) and divides each entry of exp_alpha by the sum of exp_alpha over its group, so that the attention weights of all edges pointing into the same target node sum to 1; the annotations above show exactly this pattern. A manual sketch of this grouped softmax follows.
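Here is a manual sketch of such a grouped softmax (not the library's implementation; torch_geometric.utils.softmax additionally subtracts a maximum for numerical stability):

def grouped_softmax(alpha, index, num_nodes):
    # alpha: [num_edges, heads]; index: [num_edges], the target node of each edge
    exp_alpha = alpha.exp()
    denom = torch.zeros(num_nodes, *alpha.shape[1:])
    denom.index_add_(0, index, exp_alpha)   # sum of exp per target node
    return exp_alpha / denom[index]         # normalize within each target-node group

alpha = torch.tensor([[4.], [3.5], [4.], [3.5], [3.], [5.], [4.]])
index = torch.tensor([1, 2, 0, 0, 0, 1, 2])
print(grouped_softmax(alpha, index, num_nodes=3))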
This out is exactly the set of attention coefficients $\alpha_{ij}$, and with that the GAT walkthrough is complete. As a final check, the hand computation can be compared against the library as sketched below.
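A sketch of that check, pinning the parameters as assumed earlier (the attribute names lin_l / att_l / att_r / bias follow the GATConv version quoted in this post):

conv = GATConv(in_channels=2, out_channels=1, heads=1)
with torch.no_grad():
    conv.lin_l.weight.fill_(1.)   # Theta: sum the two input features (lin_r is the same object)
    conv.att_l.fill_(0.5)         # a_l^T: multiply by 0.5
    conv.att_r.fill_(0.5)         # a_r^T: multiply by 0.5
    conv.bias.zero_()             # drop the bias for a clean comparison

out, (ei, alpha) = conv(graph.x, graph.edge_index, return_attention_weights=True)
print(alpha.squeeze())   # should match the hand-computed coefficients, in the edge order of ei
print(out)               # the aggregated node features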

Summary

Every part should be fairly easy to follow except for message; for that one, the post works through concrete numbers and lists both the inputs and the outputs, so with a careful look it should not be hard to figure out.
