将三个 Data 对象组成一个列表：
# Build three tiny 4-node graphs and collect them into the list `data`.
# Shapes: edge_index is 2 x E, x is N x in_features, edge_attr is E x edge_dim,
# y holds one label per node (length N).

# Graph 0: edges 1->0, 2->0, 3->0; every node feature is 1.
edge_index = torch.tensor([[1, 2, 3], [0, 0, 0]], dtype=torch.long)  # 2 x E
x = torch.tensor([[1], [1], [1], [1]], dtype=torch.float)  # N x in_features
edge_attr = torch.tensor([[10], [20], [30]], dtype=torch.float)  # E x edge_dim
y = torch.tensor([1, 2, 3, 4])

# Graph 1: edges 2->1, 3->1, 0->3; every node feature is 2.
edge_index1 = torch.tensor([[2, 3, 0], [1, 1, 3]], dtype=torch.long)  # 2 x E
x1 = torch.tensor([[2], [2], [2], [2]], dtype=torch.float)  # N x in_features
edge_attr1 = torch.tensor([[50], [60], [70]], dtype=torch.float)  # E x edge_dim
y1 = torch.tensor([1, 2, 3, 4])

# Graph 2: edges 0->1, 2->3, 1->2; every node feature is 3.
edge_index2 = torch.tensor([[0, 2, 1], [1, 3, 2]], dtype=torch.long)  # 2 x E
x2 = torch.tensor([[3], [3], [3], [3]], dtype=torch.float)  # N x in_features
edge_attr2 = torch.tensor([[80], [90], [100]], dtype=torch.float)  # E x edge_dim
y2 = torch.tensor([1, 2, 3, 4])

# One Data object per graph, in graph order; `length` is an extra custom
# attribute carried along on each object.
data = [
    Data(x=feats, y=labels, edge_attr=attrs, edge_index=edges, length=4)
    for feats, labels, attrs, edges in (
        (x, y, edge_attr, edge_index),
        (x1, y1, edge_attr1, edge_index1),
        (x2, y2, edge_attr2, edge_index2),
    )
]
把这三个 Data 对象送入 DataLoader 中：
# Batch all three graphs at once (batch_size=3 -> a single iteration) and
# inspect how the attributes were merged into one large disconnected graph.
# NOTE(review): this assumes `DataLoader` is PyTorch Geometric's loader
# (torch_geometric.loader.DataLoader); torch.utils.data.DataLoader cannot
# collate `Data` objects — confirm the import.
for g in DataLoader(data, batch_size=3):
    print(g.x)
    print(g.y)
    print(g.edge_attr)
    # Unlike the tensors above, edge_index is not merely concatenated: each
    # graph's node indices are shifted by the number of nodes in the graphs
    # that precede it in the batch.
    print(g.edge_index)
最终输出：
tensor([[1.],
[1.],
[1.],
[1.],
[2.],
[2.],
[2.],
[2.],
[3.],
[3.],
[3.],
[3.]])
tensor([1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4])
tensor([[ 10.],
[ 20.],
[ 30.],
[ 50.],
[ 60.],
[ 70.],
[ 80.],
[ 90.],
[100.]])
tensor([[ 1, 2, 3, 6, 7, 4, 8, 10, 9],
[ 0, 0, 0, 5, 5, 7, 9, 11, 10]])
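注意：g.edge_index 发生了变化。批处理时，PyG 会把每张图的 edge_index 加上该图之前所有图的节点总数（这里第二张图 +4，第三张图 +8），使所有边在合并后的大图中仍指向正确的节点；而 x、y、edge_attr 只是沿第 0 维直接拼接。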