我被这个问题折腾了好几天,因为在处理一个很庞大的图数据,需要采样训练,然后数据从原始数据处理成图数据,导入DGL采样模块中,结果就一直报错如下。
nodes = utils.prepare_tensor_dict(g, nodes, "nodes")
File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 91, in prepare_tensor_dict
return {
File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 92, in <dictcomp>
key: prepare_tensor(g, val, '{}["{}"]'.format(name, key))
File "/data1/fangmengcheng/anaconda3/envs/dgld/lib/python3.8/site-packages/dgl/utils/checks.py", line 36, in prepare_tensor
raise DGLError(
dgl._ffi.base.DGLError: Expect argument "nodes["_N"]" to have data type torch.int32. But got torch.int64.
报错意思是:dgl期望输入的图节点数据类型为torch.int32,但是实际输入的数据类型为torch.int64。
也就是dgl需要torch.int32,但是我不小心把输入的数据类型弄成了torch.int64,我一直以为是不是我数据类似问题,结果就一直从头到尾把所有数据类型都强制转换成torch.int32,结果还是报错如上,折腾了真的好几天,一直不明白为啥!!!!
结果就在刚刚,我不小心把数据类型强制转换成torch.int64,结果结果,居然跑通了!!!!
什么鬼啊这是,这个BUG怎么来的,我吐了啊,折腾了我好几天啊。。。。。
下面贴一段代码,大家可以自己试试:
import torch
import numpy as np
import dgl
us = np.random.randint(0,1000,size=[10000])
vs = np.random.randint(0,1000,size=[10000])
#注意下面强制转换输入数据为torch.int32,会报错
graph = dgl.graph((torch.tensor(us, dtype=torch.int32), torch.tensor(vs, dtype=torch.int32)))
sampler = dgl.dataloading.NeighborSampler([64, 32])
train_nids = np.arange(1000)
train_loader = dgl.dataloading.DataLoader(
graph, train_nids, sampler,
batch_size=32,
shuffle=True,
drop_last=False,
num_workers=4)
for input_nodes, output_nodes, blocks in train_loader:
print(input_nodes)
上面会报错,你只需要修改成torch.int64即可,如下。
import torch
import numpy as np
import dgl
us = np.random.randint(0,1000,size=[10000])
vs = np.random.randint(0,1000,size=[10000])
#注意下面强制转换输入数据为torch.int32,会报错
graph = dgl.graph((torch.tensor(us, dtype=torch.int64), torch.tensor(vs, dtype=torch.int64)))
sampler = dgl.dataloading.NeighborSampler([64, 32])
train_nids = np.arange(1000)
train_loader = dgl.dataloading.DataLoader(
graph, train_nids, sampler,
batch_size=32,
shuffle=True,
drop_last=False,
num_workers=4)
for input_nodes, output_nodes, blocks in train_loader:
print(input_nodes)
正常采样了。。。。。
我吐了,不带这么坑人的。。。。。