torch_geometric实现GCN和LightGCN

torch_geometric实现GCN和LightGCN

  • 题记
  • demo示意图
  • GCN代码
  • LightGCN代码
  • 参考博文及感谢

题记

使用torch_geometric实现GCN和LightGCN,以后可能要用,做一下备份

demo示意图

torch_geometric实现GCN和LightGCN_第1张图片

GCN代码

$$\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}$$

其中 $\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}$ 为加入自连接后的邻接矩阵,$\mathbf{\hat{D}}$ 为 $\mathbf{\hat{A}}$ 对应的对角度矩阵,$\mathbf{\Theta}$ 为可学习的参数矩阵。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from torch_geometric.nn.inits import uniform, ones


torch.manual_seed(2023)  # fixed seed so the uniform weight initialization below is reproducible

# Defaults of the GCN layer defined below:
#   \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
#                         \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}
#   - self-loops are added and messages are scaled by the edge weights;
#   - after propagation (and the bias), each output row is L2-normalized.
# NOTE: written as comments rather than a bare (non-raw) string literal, because
# "\mathbf", "\hat" and "\Theta" are invalid escape sequences in a normal string
# and trigger DeprecationWarning / SyntaxWarning (Python 3.12+).


class BaseModel(MessagePassing):
    """GCN layer: X' = D^{-1/2} A_hat D^{-1/2} X Theta (+ bias), optionally row-L2-normalized.

    Args:
        in_channels: size of each input feature vector (rows of ``weight``).
        out_channels: size of each output feature vector (columns of ``weight``).
        normalize: if True, ``update`` L2-normalizes each output row.
        self_loops: if True, ``forward`` adds missing self-loops with weight 1.
        bias: if True, a learnable bias is added after aggregation.
        aggr: aggregation scheme forwarded to ``MessagePassing`` (e.g. ``'add'``).
        **kwargs: extra arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, normalize=True, self_loops=True, bias=True, aggr='add', **kwargs):
        super(BaseModel, self).__init__(aggr=aggr, **kwargs)
        self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops
        self.normalize = normalize
        # Learnable projection Theta; allocated uninitialized and filled by reset_parameters().
        self.weight = Parameter(torch.empty(in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialize ``weight`` and, if present, ``bias``.

        Uses torch_geometric's ``uniform`` initializer with ``in_channels`` as the size argument.
        """
        uniform(self.in_channels, self.weight)
        if self.bias is not None:
            uniform(self.in_channels, self.bias)

    def forward(self, x, edge_index, edge_weight=None):
        """Project node features by ``weight``, then propagate them over ``edge_index``.

        Args:
            x: node feature matrix of shape (num_nodes, in_channels).
            edge_index: edge index tensor of shape (2, num_edges).
            edge_weight: optional per-edge weights; treated as all ones when None.

        Returns:
            Aggregated node features (bias added / row-normalized per the constructor flags).
        """
        if self.self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
        x = torch.matmul(x, self.weight)  # X Theta: the learnable linear projection
        # propagate() calls self.message, self.aggregate and self.update in turn.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_index, size, edge_weight):
        """Scale each message by deg^{-1/2}[row] * edge_weight * deg^{-1/2}[col].

        A missing ``edge_weight`` (None) is treated as all ones; previously
        ``tensor * None`` raised a TypeError.
        """
        row, col = edge_index
        if edge_weight is None:
            edge_weight = x_j.new_ones(row.size(0))
        # NOTE(review): degree() counts edges and ignores edge_weight, so this is the
        # unweighted degree over edge_index[0]. It equals the weighted degree only when
        # every weight is 1 (true in the demo below).
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # nodes with degree 0: 0^-0.5 = inf -> 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        # To skip adjacency normalization entirely, use `norm = edge_weight` instead.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Add the bias (if any), then optionally L2-normalize each row."""
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)  # row-wise L2 normalization
        return aggr_out

    def __repr__(self):
        # Was `__repr`, which Python name-mangles to `_BaseModel__repr`, so it was never used.
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


# Toy input: 5 nodes, node i carries the constant 3-dim feature row (i + 1).
x = torch.arange(1.0, 6.0).unsqueeze(1).repeat(1, 3)
GCN = BaseModel(in_channels=3, out_channels=3, self_loops=True, aggr="add")
# (2, 8) edge index; every edge also appears reversed, so the graph is symmetric.
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1],
                           [4, 0, 0, 1, 0, 1, 3, 3]])
# Uniform weights: 0.5 scaled by 2, i.e. every edge ends up with weight 1.0.
edge_weight = torch.full((edge_index.size(1),), 0.5) * 2

h = F.leaky_relu(GCN(x, edge_index, edge_weight=edge_weight))
print(h)

LightGCN代码

$$\mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}$$

与 GCN 相比,LightGCN 去掉了可学习参数矩阵 $\mathbf{\Theta}$(下面的代码中也没有偏置和非线性激活),并且默认不加自连接。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree, add_remaining_self_loops
from torch_geometric.nn.inits import uniform, ones

torch.manual_seed(2023)  # fixed seed (the LightGCN layer below has no random parameters itself)

# Defaults of the LightGCN layer defined below:
#   \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
#                         \mathbf{\hat{D}}^{-1/2} \mathbf{X}
#   - no self-loops are added; messages are scaled by the edge weights;
#   - no normalization is applied after propagation.
# NOTE: written as comments rather than a bare (non-raw) string literal, because
# "\mathbf" and "\hat" are invalid escape sequences in a normal string and
# trigger DeprecationWarning / SyntaxWarning (Python 3.12+).


class BaseModel(MessagePassing):
    """LightGCN propagation: X' = D^{-1/2} A_hat D^{-1/2} X, with no weight matrix and no bias.

    Args:
        in_channels: input feature size (stored; not used in the computation).
        out_channels: output feature size (stored; not used in the computation).
        normalize: if True, ``update`` L2-normalizes each output row.
        self_loops: if True, ``forward`` adds missing self-loops with weight 1.
        aggr: aggregation scheme forwarded to ``MessagePassing`` (e.g. ``'add'``).
        **kwargs: extra arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, normalize=False, self_loops=False, aggr='add', **kwargs):
        super(BaseModel, self).__init__(aggr=aggr, **kwargs)
        self.aggr = aggr
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.self_loops = self_loops
        self.normalize = normalize

    def forward(self, x, edge_index, edge_weight=None):
        """Propagate node features over ``edge_index`` without any projection.

        Args:
            x: node feature matrix; ``x.size(0)`` is used as the number of nodes.
            edge_index: edge index tensor of shape (2, num_edges).
            edge_weight: optional per-edge weights; treated as all ones when None.

        Returns:
            Aggregated node features (row-normalized if ``normalize`` is set).
        """
        if self.self_loops:
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value=1, num_nodes=x.size(0))
        # propagate() calls self.message, self.aggregate and self.update in turn.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_index, size, edge_weight):
        """Scale each message by deg^{-1/2}[row] * edge_weight * deg^{-1/2}[col].

        A missing ``edge_weight`` (None) is treated as all ones; previously
        ``tensor * None`` raised a TypeError. That was the default path here,
        because ``self_loops`` defaults to False.
        """
        row, col = edge_index
        if edge_weight is None:
            edge_weight = x_j.new_ones(row.size(0))
        # NOTE(review): degree() counts edges and ignores edge_weight, so this is the
        # unweighted degree over edge_index[0]. It equals the weighted degree only when
        # every weight is 1 (true in the demo below).
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  # nodes with degree 0: 0^-0.5 = inf -> 0
        norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        # To skip adjacency normalization entirely, use `norm = edge_weight` instead.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        """Optionally L2-normalize each aggregated row."""
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)  # row-wise L2 normalization
        return aggr_out

    def __repr__(self):
        # Was `__repr`, which Python name-mangles to `_BaseModel__repr`, so it was never used.
        return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


# Toy input: 5 nodes, node i carries the constant 3-dim feature row (i + 1).
x = torch.arange(1.0, 6.0).unsqueeze(1).repeat(1, 3)
# NOTE: self_loops=True overrides the layer's default (False) for this demo.
Lightgcn = BaseModel(in_channels=3, out_channels=3, self_loops=True, aggr="add")
# (2, 8) edge index; every edge also appears reversed, so the graph is symmetric.
edge_index = torch.tensor([[0, 1, 3, 3, 4, 0, 0, 1],
                           [4, 0, 0, 1, 0, 1, 3, 3]])
# Uniform weights: 0.5 scaled by 2, i.e. every edge ends up with weight 1.0.
edge_weight = torch.full((edge_index.size(1),), 0.5) * 2

h = Lightgcn(x, edge_index, edge_weight=edge_weight)
print(h)

参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考1:MMGCN 论文的开源代码
https://github.com/weiyinwei/MMGCN

你可能感兴趣的:(推荐系统,Pytorch,图神经网络,pytorch,人工智能,python)