Pytorch_Geometric(PyG)使用DataLoader报错RuntimeError: Sizes of tensors must match except in dimension 0.

使用Pytorch_Geometric(PyG)时构建DataLoader,从DataLoader获取样本Batch时报错:RuntimeError: Sizes of tensors must match except in dimension 0.
Pytorch_Geometric(PyG)使用DataLoader报错RuntimeError: Sizes of tensors must match except in dimension 0._第1张图片

报错原因是数据对齐错误,1个batch是多个样本的集合,在样本拼接成集合时出现错误,其规律如下:

  • 使用pytorch-geometric的dataloader时,batch的各个样本合并规则
    • 属性edge_index规则特殊,每个样本edge_index为 2 × e i 2\times e_i 2×ei,则合并n个样本形成一个batch之后的batch.edge_index大小为 2 × ( ∑ i = 1 n e i ) 2\times(\sum_{i=1}^n e_i) 2×(i=1nei)
    • 其他所有属性如果为tensor,则按照第一个维度扩展,例如对于属性 x x x,第一个样本大小为 d 1 × d 2 d_1\times d_2 d1×d2,第二个样本大小为 d 3 × d 2 d_3\times d_2 d3×d2,则如果有一个batch包含这两个样本,batch.x的大小会是 ( d 3 + d 1 ) × d 2 (d_3+d_1)\times d_2 (d3+d1)×d2这里一个巨坑,要求除了第一个维度之外,其他维度大小都必须要相同!! 否则会报错RuntimeError: Sizes of tensors must match except in dimension 0.
    • 其他属性如果不是tensor,就会正常按照列表返回,batch.x=[ 样本1的x,样本2的x,样本3的x]

如何解决:

  • 如果是使用torch tensor引起的,可以考虑想办法对齐除了第一个维度外,其他维度的宽度。
  • 如果没办法对齐,使用非tensor数据类型替换,例如列表。
  • 最后的选择,指定batch_size=1以规避。dataloader=DataLoader(MyData,batch_size=1)

2022/06/23原始


2023/02/20更新
https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html
这个是官网更详细的描述,直接看这个简单

你可能感兴趣的:(pytorch,深度学习,python)