使用 PyTorch 的 PyG(PyTorch Geometric)搭建图神经网络时报错:
can not import topk, filter_adj from torch_geometric.nn.pool.topk_pool
原因是版本问题:新版 PyG 的接口发生了变化,对应关系如下:
topk => SelectTopK
filter_adj => FilterEdges
from torch_geometric.nn.pool.connect import FilterEdges
from torch_geometric.nn.pool.select import SelectTopK
但替换之后仍然无法运行,
于是查看 SelectTopK / FilterEdges 的源码,
发现其中定义了 topk、filter_adj 两个函数,但直接 import 也不能用。
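于是把这两个函数手动写在 layers.py 里即可运行(它们还依赖 torch、typing 的 Optional/Union/Tuple,以及 torch_geometric.utils 的 scatter、cumsum 和 torch_geometric.utils.num_nodes 的 maybe_num_nodes,需要一并导入):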
def topk(
    x: Tensor,
    ratio: Optional[Union[float, int]],
    batch: Tensor,
    min_score: Optional[float] = None,
    tol: float = 1e-7,
) -> Tensor:
    """Return the indices of the entries of ``x`` to keep in each graph.

    One of two selection rules is applied; ``min_score`` takes precedence
    over ``ratio`` when both are given.

    Args:
        x: Score tensor. It is flattened with ``view(-1)`` before sorting,
            so it is treated as one score per entry.
        ratio: If ``>= 1``, at most ``int(ratio)`` entries are kept per
            graph. Otherwise ``ceil(ratio * num_nodes)`` entries are kept
            per graph, where ``num_nodes`` is that graph's entry count.
        batch: Graph id of every entry of ``x``.
        min_score: If given, keep every entry whose score is strictly
            greater than ``min(min_score, graph_max - tol)``.
        tol: Margin subtracted from each graph's maximum score so that the
            best entry of a graph always passes the ``min_score`` test.

    Returns:
        A 1-D tensor of indices into ``x``. In the ``ratio`` branch the
        indices are grouped by graph and ordered by descending score
        within each graph.

    Raises:
        ValueError: If both ``ratio`` and ``min_score`` are ``None``.
    """
    # NOTE(review): `scatter` and `cumsum` are assumed to be the helpers from
    # `torch_geometric.utils` (where `cumsum` prepends a leading zero, so
    # `ptr[i]` is the offset of graph i) - confirm they are imported.
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)

        perm = (x > scores_min).nonzero().view(-1)
        return perm

    if ratio is not None:
        # Number of entries belonging to each graph.
        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')

        if ratio >= 1:
            k = num_nodes.new_full((num_nodes.size(0),), int(ratio))
        else:
            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

        # Sort by score first, then stably by graph id: the result is
        # grouped by graph with descending scores inside each group.
        x, x_perm = torch.sort(x.view(-1), descending=True)
        batch = batch[x_perm]
        batch, batch_perm = torch.sort(batch, descending=False, stable=True)

        # Rank of every entry inside its own graph.
        arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        ptr = cumsum(num_nodes)
        batched_arange = arange - ptr[batch]
        mask = batched_arange < k[batch]

        return x_perm[batch_perm[mask]]

    # Without this, the function would silently return None.
    raise ValueError("At least one of the 'ratio' and 'min_score' parameters "
                     "must be specified")
def filter_adj(
    edge_index: Tensor,
    edge_attr: Optional[Tensor],
    node_index: Tensor,
    cluster_index: Optional[Tensor] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Keep only the edges between retained nodes and relabel their ends.

    Args:
        edge_index: Edge tensor; row 0 holds one endpoint of every edge and
            row 1 the other.
        edge_attr: Optional per-edge attributes, filtered with the same
            mask as the edges.
        node_index: Ids of the nodes to keep.
        cluster_index: New id assigned to each entry of ``node_index``.
            Defaults to ``0 .. len(node_index) - 1``.
        num_nodes: Size of the relabelling table; resolved through
            ``maybe_num_nodes`` when not given.

    Returns:
        A tuple ``(edge_index, edge_attr)`` holding the relabelled
        surviving edges and their attributes (``None`` if no attributes
        were passed).
    """
    # NOTE(review): `maybe_num_nodes` is assumed to come from
    # `torch_geometric.utils.num_nodes` - confirm it is imported.
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if cluster_index is None:
        cluster_index = torch.arange(
            node_index.size(0), device=node_index.device
        )

    # Lookup table from old node id to new id; -1 marks a dropped node.
    relabel = node_index.new_full((num_nodes,), -1)
    relabel[node_index] = cluster_index

    src = relabel[edge_index[0]]
    dst = relabel[edge_index[1]]

    # An edge survives only if both of its endpoints were kept.
    keep = (src >= 0) & (dst >= 0)

    kept_attr = None if edge_attr is None else edge_attr[keep]
    return torch.stack([src[keep], dst[keep]], dim=0), kept_attr
参考官方文档
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/pool/topk_pool.html