scatter_add_() error: Expected index [1, 67, 3] to be smaller than self [9, 66, 5003] apart from dime...

scatter_add_(dim, index, src) → Tensor. Official documentation: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
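A minimal sketch of what scatter_add_ computes, using a small 2-D toy tensor (the shapes and values are made up for illustration and are not from this post). With dim=1, every (i, j) does self[i][index[i][j]] += src[i][j]; for a 3-D tensor with dim=2, as in the failing call below, the update is self[i][j][index[i][j][k]] += src[i][j][k].

```python
import torch

# Toy case with dim=1: for every (i, j), self_t[i][index[i][j]] += src[i][j]
self_t = torch.zeros(2, 5)
index = torch.tensor([[0, 1, 1],
                      [4, 4, 2]])
src = torch.ones(2, 3)

self_t.scatter_add_(1, index, src)
print(self_t)
# tensor([[1., 2., 0., 0., 0.],
#         [0., 0., 1., 0., 2.]])
```

Repeated indices accumulate (row 0 hits position 1 twice), which is the difference from scatter_(), where a later write simply overwrites an earlier one.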

  1. Calling torch.Tensor.scatter_add_() raises: Expected index [1, 67, 3] to be smaller than self [9, 66, 5003] apart from dimension 2 and to be smaller size than src [1, 67, 3]
import torch

idxes = torch.tensor([[[  59,   33,  848],
         [1818,  257, 3081],
         [   0, 3234,  320],
         [  59,   21,   16],
         [ 756,  516, 1311],
         [4990, 1286, 2835],
         [ 702, 2446, 1662],
         [ 270, 1576, 2220],
         [ 963,  201,  775],
         [   0, 3234,  320],
         [ 359, 1007, 3563],
         [4983, 3339, 2446],
         [1039, 4596, 1552],
         [ 448, 3075, 2003],
         [ 848, 1053,  407],
         [2446, 4983, 3339],
         [2236, 3056, 1059],
         [  25,  346,  940],
         [   4, 1782, 4376],
         [ 433,  475,   91],
         [ 223, 1135, 2728],
         [ 290, 2235,  610],
         [3073, 2693, 3248],
         [ 568,  426,  226],
         [2344, 2148, 2260],
         [ 601,  394, 3207],
         [   0, 3234,  320],
         [3828, 1800, 3261],
         [   0, 3234,  320],
         [1351, 4438, 1767],
         [1852, 2284, 4906],
         [4773, 3558, 1311],
         [2220, 3589, 1806],
         [3073, 2693, 3248],
         [1405,  678, 2247],
         [   0, 3234,  320],
         [2655, 2558, 3618],
         [  20, 4594, 4574],
         [  20,  775,  822],
         [ 189,  106,  102],
         [1311, 2234, 2548],
         [  93,   37,  491],
         [ 526, 1059, 2332],
         [   0, 3234,  320],
         [1282, 3268, 4381],
         [3204,  941, 4946],
         [3433, 1737, 3983],
         [2220, 1576, 3922],
         [ 642, 4518, 3075],
         [2102, 3225, 1594],
         [3728,  838, 3844],
         [1029, 2844, 2213],
         [ 739, 1025,  411],
         [3515, 4990, 4652],
         [4983, 3339, 2446],
         [ 223,   53, 3995],
         [ 408,  228,  158],
         [ 290,   33,  221],
         [ 126, 2678, 1674],
         [ 448, 2003,  253],
         [  33,  290,  221],
         [ 223,  106,  189],
         [4983, 2446, 3318],
         [3305, 1835, 4762],
         [   0, 3234,  320],
         [   0, 3234,  320],
         [   0, 3234,  320]]]).long()
probs = torch.ones([1,67,3])
tmp_trans_scores = torch.zeros([9, 66, 5003])
tmp_trans_scores.scatter_add_(2, idxes, probs)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
----> 1 tmp_trans_scores.scatter_add_(2, idxes, probs)

RuntimeError: Expected index [1, 67, 3] to be smaller than self [9, 66, 5003] apart from dimension 2 and to be smaller size than src [1, 67, 3]
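To locate which dimension triggers the message, the documented shape rules (index.size(d) <= src.size(d) for every d, and index.size(d) <= self.size(d) for every d except dim) can be checked directly. The helper check_scatter_shapes below is hypothetical, not a PyTorch API; it assumes self, index and src already have the same number of dimensions and checks sizes only, not index values.

```python
import torch

def check_scatter_shapes(self_t, dim, index, src):
    """Hypothetical helper: list the dimensions that break the documented
    scatter_add_ size constraints (index values are not checked)."""
    problems = []
    for d in range(index.dim()):
        # Against self, every dimension except the scatter dim is checked
        if d != dim and index.size(d) > self_t.size(d):
            problems.append(f"dim {d}: index {index.size(d)} > self {self_t.size(d)}")
        # Against src, every dimension is checked
        if index.size(d) > src.size(d):
            problems.append(f"dim {d}: index {index.size(d)} > src {src.size(d)}")
    return problems

# Same shapes as the failing call; the index contents do not matter here
idxes = torch.zeros(1, 67, 3, dtype=torch.long)
probs = torch.ones(1, 67, 3)
tmp_trans_scores = torch.zeros(9, 66, 5003)
print(check_scatter_shapes(tmp_trans_scores, 2, idxes, probs))
# ['dim 1: index 67 > self 66']
```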

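Cause: self, index and src must have the same number of dimensions, with index.size(d) <= src.size(d) for every dimension d, and index.size(d) <= self.size(d) for every dimension d except dim (here dim=2). "Smaller" in the message therefore means "not larger than". For idxes.size() = [1, 67, 3], self of size [9, 66, 5003] and src of size [1, 67, 3]: dim 0 is fine (1 <= 9 and 1 <= 1), dim 2 is only checked against src (3 <= 3), but dim 1 fails because 67 > 66. So either index (and src) must be at most 66 along dim 1, or self must be at least 67 along dim 1. Separately, the values in idxes must lie in [0, 5003) along dim 2, which they do here (the largest is 4990).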
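Two possible fixes, sketched under assumptions, since the post does not say which tensor has the wrong size. The idxes/probs tensors below are hypothetical stand-ins with the same shapes as in the post (random index values kept inside [0, 5003)). Option A assumes the 67th row along dim 1 is surplus and trims it; option B assumes all 67 rows matter and allocates self with 67 rows instead.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-ins with the post's shapes; values stay in [0, 5003)
idxes = torch.randint(0, 5003, (1, 67, 3))
probs = torch.ones(1, 67, 3)

# Option A (assumption: the extra row is not needed): trim dim 1 to self's size
tmp_trans_scores = torch.zeros(9, 66, 5003)
n = tmp_trans_scores.size(1)
tmp_trans_scores.scatter_add_(2, idxes[:, :n], probs[:, :n])

# Option B (assumption: all 67 rows are needed): give self 67 rows along dim 1
scores_67 = torch.zeros(9, 67, 5003)
scores_67.scatter_add_(2, idxes, probs)

# Every scattered src element adds 1.0 once: 66*3 in A, 67*3 in B
print(tmp_trans_scores.sum().item(), scores_67.sum().item())
# 198.0 201.0
```

Which option is correct depends on where the off-by-one comes from upstream (for example, a sequence that includes an extra start or end token), so it is worth checking that before simply trimming.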
