torch_geometric.nn.max_pool(
cluster,
data,
transform=None)
对由torch_geometricy .data给出的图形进行池化和粗化。
数据对象根据集群cluster中定义的集群。同一集群中的所有节点将表示为一个节点。最终节点特征由同一簇内所有节点的特征最大值定义,节点位置平均,边的index定义为同一簇内所有节点的边index的并集。
cluster (LongTensor) | 簇向量,每一个维度表示了一个点属于哪个簇 |
data (Data) | torch_geometric的data 对象 |
transform (callable, optional) | 一个函数/转换,接受粗化和池化的torch_geometry .data。数据对象,并返回转换后的版本。 |
返回torch_geometric的data 对象
假如我们一开始的data 为:
Batch(x=[9893, 1], edge_index=[2, 34637], y=[9893, 1], batch=[9893], ptr=[2])
from torch_geometric.nn import max_pool
cluster = graclus(data.edge_index, num_nodes=x.shape[0])
cluster
#tensor([ 0, 1, 1, ..., 9890, 9891, 9892])
#第i维表示第i个点在以第几个点为中心点的簇中
data_c = max_pool(
cluster,
data)
data_c
#Batch(x=[5863, 1], edge_index=[2, 21983], batch=[5863])
#分成了5863 个cluster
在mini_batch的话,需要这样写:
data_c = max_pool(
cluster,
Data(
x=data.x,
batch=data.batch,
edge_index=data.edge_index))
data
对一个cluster中中的x的特征进行最大池化操作
max_pool_x(
cluster,
x,
batch,
size: Optional[int] = None)
注意和max_pool的区别
max_pool 返回的是data,max_pool_x返回的是Tensor
max_pool 相当于max_pool_x的基础上,再对图的边进行了修改合并操作