torch_geometric笔记:max_pool 与max_pool_x

1 max_pool 

1.1 函数介绍

torch_geometric.nn.max_pool(
    cluster, 
    data,
    transform=None)

        对由torch_geometricy .data给出的图形进行池化和粗化。

        数据对象根据集群cluster中定义的集群。同一集群中的所有节点将表示为一个节点。最终节点特征由同一簇内所有节点的特征最大值定义,节点位置平均,边的index定义为同一簇内所有节点的边index的并集。

1.2 参数说明

cluster (LongTensor 簇向量,每一个维度表示了一个点属于哪个簇
data (Data)  torch_geometric的data 对象
transform (callableoptional) 一个函数/转换,接受粗化和池化的torch_geometry .data。数据对象,并返回转换后的版本。

 返回torch_geometric的data 对象

1.3 举例说明

假如我们一开始的data 为:

Batch(x=[9893, 1], edge_index=[2, 34637], y=[9893, 1], batch=[9893], ptr=[2])
from torch_geometric.nn import max_pool

cluster = graclus(data.edge_index, num_nodes=x.shape[0])
cluster
#tensor([   0,    1,    1,  ..., 9890, 9891, 9892]) 
#第i维表示第i个点在以第几个点为中心点的簇中

data_c = max_pool(
    cluster, 
    data)
data_c
#Batch(x=[5863, 1], edge_index=[2, 21983], batch=[5863]) 
#分成了5863 个cluster

1.3.1 mini-batch 的max_pool

在mini_batch的话,需要这样写:

data_c = max_pool(
    cluster, 
    Data(
        x=data.x, 
        batch=data.batch, 
        edge_index=data.edge_index))
data

2 max_pool_x

对一个cluster中中的x的特征进行最大池化操作

max_pool_x(
    cluster, 
    x, 
    batch, 
    size: Optional[int] = None)

注意和max_pool的区别

max_pool 返回的是data,max_pool_x返回的是Tensor

max_pool 相当于max_pool_x的基础上,再对图的边进行了修改合并操作

你可能感兴趣的:(pytorch学习,python)