DGL大图采样报错记录:Expect argument “nodes[“_N“]“ to have data type torch.int32. But got torch.int64.

我被这个问题折腾了好几天,因为在处理一个很庞大的图数据,需要采样训练,然后数据从原始数据处理成图数据,导入DGL采样模块中,结果就一直报错如下。

    nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
  File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 91, in prepare_tensor_dict
    return {
  File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 92, in <dictcomp>
    key: prepare_tensor(g, val, '{}["{}"]'.format(name, key))
  File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 36, in prepare_tensor
    raise DGLError(
dgl._ffi.base.DGLError: Expect argument "nodes["_N"]" to have data type torch.int32. But got torch.int64.

报错意思是:dgl期望输入的图节点数据类型为torch.int32,但是实际输入的数据类型为torch.int64。

也就是dgl需要torch.int32,但是我不小心把输入的数据类型弄成了torch.int64,我一直以为是不是我数据类似问题,结果就一直从头到尾把所有数据类型都强制转换成torch.int32,结果还是报错如上,折腾了真的好几天,一直不明白为啥!!!!

结果就在刚刚,我不小心把数据类型强制转换成torch.int64,结果结果,居然跑通了!!!!

什么鬼啊这是,这个BUG怎么来的,我吐了啊,折腾了我好几天啊。。。。。

下面贴一段代码,大家可以自己试试:

import torch
import numpy as np
import dgl

us = np.random.randint(0,1000,size=[10000])
vs = np.random.randint(0,1000,size=[10000])
#注意下面强制转换输入数据为torch.int32,会报错
graph = dgl.graph((torch.tensor(us, dtype=torch.int32), torch.tensor(vs, dtype=torch.int32)))

sampler = dgl.dataloading.NeighborSampler([64, 32])

train_nids = np.arange(1000)

train_loader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=32,
    shuffle=True,
    drop_last=False,
    num_workers=4)

for input_nodes, output_nodes, blocks in train_loader:
    print(input_nodes)

上面会报错,你只需要修改成torch.int64即可,如下。

import torch
import numpy as np
import dgl

us = np.random.randint(0,1000,size=[10000])
vs = np.random.randint(0,1000,size=[10000])
#注意下面强制转换输入数据为torch.int32,会报错
graph = dgl.graph((torch.tensor(us, dtype=torch.int64), torch.tensor(vs, dtype=torch.int64)))

sampler = dgl.dataloading.NeighborSampler([64, 32])

train_nids = np.arange(1000)

train_loader = dgl.dataloading.DataLoader(
    graph, train_nids, sampler,
    batch_size=32,
    shuffle=True,
    drop_last=False,
    num_workers=4)

for input_nodes, output_nodes, blocks in train_loader:
    print(input_nodes)

运行结果
DGL大图采样报错记录:Expect argument “nodes[“_N“]“ to have data type torch.int32. But got torch.int64._第1张图片

正常采样了。。。。。

我吐了,不带这么坑人的。。。。。

你可能感兴趣的:(图神经网络,工具,python,numpy,开发语言)