PyG异质图神经网络NotImplementedError问题

诸神缄默不语-个人CSDN博文目录

以PyG官方的数据集和示例代码来复现一下这个问题:

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='files/pyg_data', preprocess='metapath2vec')  # no transform applied
data = dataset[0]  # edge types as stored: none has 'author' as destination

print(data)

class GNN(torch.nn.Module):  # two-layer GraphSAGE written for a homogeneous graph
    def __init__(self, hidden_channels, out_channels):
        super().__init__()  # in_channels=(-1, -1): feature sizes inferred lazily
        self.conv1 = SAGEConv(in_channels=(-1, -1), out_channels=hidden_channels)
        self.conv2 = SAGEConv(in_channels=(-1, -1), out_channels=out_channels)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.relu()
        return self.conv2(h, edge_index)


model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')  # raises NotImplementedError (see traceback)

# Not reached in this reproduction: to_hetero above fails because 'author' is
# not the destination of any edge type, so conv1 yields no 'author' output for
# the subsequent .relu() call (the call_method frame in the traceback).
with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

print(out)

输出信息:

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 5416271] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }
)
my_env/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour.
  warnings.warn(
Traceback (most recent call last):
  File "try2.py", line 25, in <module>
    model = to_hetero(model, data.metadata(), aggr='sum')
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 118, in to_hetero
    return transformer.transform()
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/fx.py", line 157, in transform
    getattr(self, op)(node, node.target, node.name)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 294, in call_method
    args, kwargs = self.map_args_kwargs(node, key)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in map_args_kwargs
    args = tuple(_recurse(v) for v in node.args)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in <genexpr>
    args = tuple(_recurse(v) for v in node.args)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 387, in _recurse
    raise NotImplementedError
NotImplementedError

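从警告信息可以看出,问题出在 `author` 类型的节点上:它不是任何一种边类型的目标节点(即没有入边),因此第一层卷积不会为 `author` 产生新的表示,`to_hetero` 在转换后续的 `.relu()` 调用时找不到对应的输入,于是抛出 `NotImplementedError`。
解决方案就是让所有类型的节点都有入边,例如用 `T.ToUndirected()` 为已有的边补上反向边: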

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

# T.ToUndirected() adds a reverse edge type (rev_*) for each edge type between
# different node types and symmetrizes (paper, cites, paper), so every node
# type -- including 'author' -- is now the destination of some edge type.
dataset = OGB_MAG(
    root='files/pyg_data',
    preprocess='metapath2vec',
    transform=T.ToUndirected(),
)
data = dataset[0]

print(data)

class GNN(torch.nn.Module):  # two-layer GraphSAGE written for a homogeneous graph
    def __init__(self, hidden_channels, out_channels):
        super().__init__()  # in_channels=(-1, -1): feature sizes inferred lazily
        self.conv1 = SAGEConv(in_channels=(-1, -1), out_channels=hidden_channels)
        self.conv2 = SAGEConv(in_channels=(-1, -1), out_channels=out_channels)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.relu()
        return self.conv2(h, edge_index)


model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
# Every node type is now a destination of some edge type, so the conversion
# succeeds. aggr='sum' presumably combines the per-edge-type results arriving
# at the same node type -- confirm against the PyG to_hetero documentation.
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)

print(out)  # dict keyed by node type ('paper', 'author', ...) -> output tensor

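`T.ToUndirected()` 将异质图转换为无向图:对连接不同类型节点的边,它添加对应的反向边类型(`rev_affiliated_with`、`rev_writes`、`rev_has_topic`);对同类型节点之间的 `(paper, cites, paper)`,则直接补全反向边(边数由 5416271 增加到 10792672)。这样每种节点类型都至少是一种边的目标节点,就能得到正常的输出结果: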

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)
{'paper': tensor([[-0.8212, -0.2630, -0.7286,  ...,  1.1904,  0.1617, -0.5388],
        [-1.2484, -0.3707, -1.0336,  ...,  0.9618, -0.0373, -0.1125],
        [-0.5375,  0.0357, -0.6772,  ...,  1.2185,  0.2292, -0.2130],
        ...,
        [-0.9934, -0.2688, -0.9547,  ...,  1.3144,  0.1519, -0.2015],
        [-1.4711, -0.6607, -0.7509,  ...,  2.3383,  0.6815, -1.0679],
        [-0.4352, -0.4255, -0.6907,  ...,  1.1532,  0.1152, -0.9703]]), 'author': tensor([[-0.2782,  0.1771,  0.4187,  ..., -0.5233, -0.2969,  0.2438],
        [-0.4543,  0.1019,  0.1637,  ..., -0.7748, -0.2809,  0.2598],
        [-0.1613, -0.0481, -0.2491,  ..., -0.6227, -0.4217,  0.1335],
        ...,
        [-0.4908,  0.2382,  0.2973,  ..., -0.7266, -0.2486,  0.6449],
        [-0.2819,  0.0125,  0.9843,  ..., -1.9652, -0.4280, -0.4842],
        [-0.4236, -0.1222,  1.0246,  ..., -2.0615, -0.3246, -0.1771]]), 'institution': tensor([[ 0.3911, -1.3527, -0.6624,  ...,  0.2732,  0.5270,  0.5756],
        [ 0.1512, -0.6687, -0.6516,  ...,  0.1482,  0.2535,  0.1935],
        [ 0.1933, -1.1643, -0.4936,  ...,  0.5382,  0.3407,  0.2199],
        ...,
        [ 0.1489, -0.3021, -0.3390,  ...,  0.2690,  0.1571, -0.0781],
        [ 0.1855, -0.4848, -0.3205,  ...,  0.4728,  0.0659,  0.1500],
        [ 0.1724, -0.0682, -0.0894,  ...,  0.1189,  0.1230, -0.2249]]), 'field_of_study': tensor([[ 0.1929, -0.5402, -0.5714,  ..., -0.4296,  0.4376, -0.0660],
        [-0.2281,  0.0773, -0.0486,  ..., -0.0544, -0.2894,  0.2706],
        [-0.2798, -0.1967, -0.3376,  ..., -0.3098, -0.1610,  0.1120],
        ...,
        [ 0.0775, -0.5927, -0.6084,  ..., -0.3190,  0.2483, -0.1418],
        [ 0.0286, -0.7393, -0.6629,  ..., -0.4745,  0.8461, -0.1554],
        [-0.0804, -0.5598, -0.8517,  ..., -0.2317,  0.3234, -0.0520]])}

你可能感兴趣的:(人工智能学习笔记,PyG,图神经网络,异质图,NotImplementedError,bug)