PyG 中 DataLoader 的注意点

先构造三个 `Data` 对象（每个图 4 个节点、3 条边），并把它们组成一个列表（需要先 `import torch` 和 `from torch_geometric.data import Data`）：

# Three toy graphs. Each has 4 nodes with a 1-dim feature and 3 edges with a
# 1-dim edge feature. Shapes: edge_index (2, E), x (N, in_dim),
# edge_attr (E, edge_dim).
# NOTE: this snippet assumes `import torch` and
# `from torch_geometric.data import Data` have already been run.

# --- graph 0: every node feature is 1 ---
edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]], dtype=torch.long)
x = torch.full((4, 1), 1.0, dtype=torch.float)
edge_attr = torch.tensor([[10], [20], [30]], dtype=torch.float)

# --- graph 1: every node feature is 2 ---
edge_index1 = torch.tensor([[2, 3, 0], [1, 1, 3]], dtype=torch.long)
x1 = torch.full((4, 1), 2.0, dtype=torch.float)
edge_attr1 = torch.tensor([[50], [60], [70]], dtype=torch.float)

# --- graph 2: every node feature is 3 ---
edge_index2 = torch.tensor([[0, 2, 1], [1, 3, 2]], dtype=torch.long)
x2 = torch.full((4, 1), 3.0, dtype=torch.float)
edge_attr2 = torch.tensor([[80], [90], [100]], dtype=torch.float)

# The same 4 target values for every graph (4 values, matching the 4 nodes).
# Each name gets its own, independent tensor object.
y, y1, y2 = (torch.tensor([1, 2, 3, 4]) for _ in range(3))

# Each graph additionally carries a plain-int extra attribute `length`.
data = [
    Data(x=x, y=y, edge_attr=edge_attr, edge_index=edge_index, length=4),
    Data(x=x1, y=y1, edge_attr=edge_attr1, edge_index=edge_index1, length=4),
    Data(x=x2, y=y2, edge_attr=edge_attr2, edge_index=edge_index2, length=4),
]

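把这三个 `Data` 交给 PyG 的 `DataLoader`（`from torch_geometric.loader import DataLoader`；PyG 2.0 之前的版本为 `torch_geometric.data.DataLoader`）。注意不是 `torch.utils.data.DataLoader`，后者不会做下面这种图的拼接：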

for g in DataLoader(data,batch_size=3):

    print(g.x)
    print(g.y)
    print(g.edge_attr)
    print(g.edge_index)

输出如下（batch_size=3，因此三个图被合并成了一个 batch）：

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [2.],
        [2.],
        [2.],
        [2.],
        [3.],
        [3.],
        [3.],
        [3.]])
tensor([1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4])
tensor([[ 10.],
        [ 20.],
        [ 30.],
        [ 50.],
        [ 60.],
        [ 70.],
        [ 80.],
        [ 90.],
        [100.]])
tensor([[ 1,  2,  3,  6,  7,  4,  8, 10,  9],
        [ 0,  0,  0,  5,  5,  7,  9, 11, 10]])

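注意 `g.edge_index` 发生了变化。`x`、`y`、`edge_attr` 只是沿第 0 维直接拼接。`edge_index` 则不同：拼接时，每个图的节点编号都会加上它前面所有图的节点总数作为偏移量。第一个图偏移 0。第二个图偏移 4，例如 `[[2, 3, 0], [1, 1, 3]]` 变为 `[[6, 7, 4], [5, 5, 7]]`。第三个图偏移 8。这样，边的索引才能指向拼接后 `g.x` 中的正确节点（`g.batch` 则记录了每个节点属于哪个图）。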

你可能感兴趣的:(pyg 中dataloader 注意点)