PyG是面向图数据的,它同时支持同构图(homogeneous graphs)和异构图(heterogeneous)。同构图指只包含一种类型的节点和边的图(下图左)。而异构图指包含两种及以上类型的节点和边的图(下图右)。
在PyG中,同构图被描述为torch_geometric.data.Data
类的实例,而异构图被描述为torch_geometric.data.HeteroData
的实例。
本文主要介绍PyG关于同构图的的相关操作,操作环境为:
pytorch = 1.10.1
cuda = 11.3
torch_geometric = 2.0.4
同构图是用Data
类是进行描述的,因此首先查看其初始化函数的参数列表:
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
edge_attr: OptTensor = None, y: OptTensor = None,
pos: OptTensor = None, **kwargs):
对应的参数说明为:
参数 | 说明 |
---|---|
x |
节点特征矩阵,shape为[num_nodes, num_node_features] ,Tensor 类型 |
edge_index |
边索引(边表),shape为[2, num_edges] ,在这个包含两行的数组中,第1行与第2行中对应索引位置的值分别表示一条边的源节点和目标节点,LongTensor 类型。 |
edge_attr |
边特征矩阵,shape为[num_edges, num_edges_featrues] ,Tensor 类型 |
y |
图级标签或节点级标签,Tensor 类型 |
pos |
节点的位置矩阵,shape为[num_nodes, num_dimensions] ,Tensor 类型 |
**kwargs |
用户自定义的额外属性,传入格式需为attr_name=attr_value |
Data
类的初始化函数中参数默认值都为None
,这意味着没有哪个参数是必要的,在实际使用时需要根据待构造图的实际情况来传入相应的属性。
在PyG中,对于一个Data
对象其包含众多属性和方法,这里列举一下常用的,更详细的请参见官网Data部分。
方法/属性 | 说明 |
---|---|
num_node_features /num_features |
图节点数特征(维度)数 |
num_edge_features |
图中边的特征(维度)数 |
keys |
图属性名列表 |
num_edges |
图边数 |
num_nodes |
图节点数 |
is_directed() /is_undirected() |
是否为有向图/无向图 |
is_cuda |
图是否存储在GPU上 |
has_self_loops() /contains_self_loops() |
图中节点是否包含自环 |
has_isolated_nodes() /contains_isolated_nodes() |
图中是否包含孤立节点 |
to(device) |
将图实例放置到指定的设备(GPU或CPU)上 |
clone() |
对图进行深拷贝 |
首先创建一个包含5个顶点、12条边的无向图。需要注意的是,在edge_index
中边都有有方向的,即从源节点到目标节点。若要创建从节点 v v v到节点 u u u的无向边,则需要在edge_index
中传入两条相应的边,即(u,v), (v,u)
。
import torch
import torch_geometric.data as data
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
edge_index = torch.LongTensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4],
[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0]])
x = torch.ones(5, 2)
g = data.Data(edge_index=edge_index, x=x)
print(g)
"""
Data(edge_index=[2, 12], x=[5, 2])
"""
# 转换为nextworkx格式的图并可视化
g = to_networkx(g)
nx.draw(g, with_labels=g.nodes)
plt.show()
创建的图可视化结果为:
对上述创建的Data
对象应用2.2节介绍的部分方法实例代码如下:
print(g.num_nodes, g.num_edges)
# 5 12
print(g.keys)
# ['x', 'edge_index']
print(g.num_node_features)
# 2
print(g.is_undirected())
# True
print(g.has_isolated_nodes())
# False
若要将自己创建的图实例保存到本地磁盘或从本地磁盘加载保存的图数据,可以使用torch.save()
和torch.load()
:
torch.save([g], "temp/data.pt")
g = torch.load("temp/data.pt")
print(g)
# [Data(edge_index=[2, 12], x=[5, 2])]
在torch_geometric.utils
模块中包含了许多对图数据的高级操作方法,下面将对其中最常用的方法进行介绍。
通过degree(index, num_nodes=None)
方法可以计算图中节点的度,其中:
index
:edge_index
中的两个维度中任意一个num_nodes
:节点的数量,可选参数示例代码:
print(degree(g.edge_index[0]))
# tensor([3., 2., 3., 2., 2.])
print(degree(g.edge_index[1]))
# tensor([3., 2., 3., 2., 2.])
自环指节点指向自身的边。在utils
中处理自环的方法包括:
contains_self_loops(edge_index)
:判断图中节点是否包含自环。remove_self_loops(edge_index)
:删除图中所有的自环。add_self_loops(edge_index)
:为图中的节点添加自环,对于有自环的节点,它会再为该节点添加一个自环。add_remaining_self_loops
:为图中还没有自环的节点添加自环。示例代码:
print(contains_self_loops(g.edge_index))
# False
edge_index, _ = add_self_loops(g.edge_index)
print(edge_index)
"""
tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4],
[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0, 0, 1, 2, 3, 4]])
"""
edge_index, _ = add_remaining_self_loops(edge_index)
print(edge_index)
"""
没有添加新的自环
tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 0, 1, 2, 3, 4],
[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0, 0, 1, 2, 3, 4]])
"""
edge_index, _ = remove_self_loops(edge_index)
print(edge_index)
"""
tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4],
[1, 2, 4, 0, 2, 1, 0, 3, 2, 4, 3, 0]])
"""
utils
中提供了若干方法用来在图中提取子图。
subgraph(subset, edge_index)
:根据给定的图节点集合subset
来抽取图中包含这些节点的子图。k_hop_subgraph(node_idx, num_hops, edge_index)
:提取给定节点集node_idx
能经过num_hops
跳到达的所有节点组成的子图(包括node_idx
本身)。sub_graph
方法示例代码:
def draw(edge_index):
graph = data.Data(edge_index=edge_index)
graph = to_networkx(graph)
print(graph.nodes)
nx.draw(graph, with_labels=graph.nodes)
plt.show()
edge_index, _ = subgraph(subset=torch.LongTensor(
[0, 1, 2]), edge_index=g.edge_index)
draw(edge_index)
提取的子图可视化如下所示:
k_hop_subgraph
方法的示例代码如下所示:
g = k_hop_subgraph(
node_idx=[0], num_hops=1, edge_index=g.edge_index)
print(g)
"""
(tensor([0, 1, 2, 4]), tensor([[0, 0, 0, 1, 1, 2, 2, 4],
[1, 2, 4, 0, 2, 1, 0, 0]]), tensor([0]), tensor([ True, True, True, True, True, True, True, False, False, False,
False, True]))
"""
从上图可以看出,该方法返回一个4元组,元组的4个元素依次为:子图的节点集、子图的边集、用来查询的节点集(中心节点集)、指示原始图g
中的边是否在子图中的布尔数组。我们取子图的边集进行可视化结果如下:
通过to_undirected(edge_index)
可以将一个图转换为无向图:
edge_index = torch.LongTensor([[0, 0], [1, 2]])
edge_index = to_undirected(edge_index)
print(edge_index)
"""
tensor([[0, 0, 1, 2],
[1, 2, 0, 0]])
"""
参考资料:
本文主要介绍了PyG中对单个图的相关操作方法,从上面的操作可以看出对于PyG对图结构的操作其实就是在操作edge_index
(该属性本来就用来在PyG中保存图的结构信息)。