pytorch torch.view().expand().long()

print(torch.arange(0, n_class))  # n_class=60
print(torch.arange(0, n_class).view(n_class, 1, 1))  # 变换形状view(x,y,z),xyz是维度
print(torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1))  # expand(原,行增,列增)
print(torch.arange(0, n_class).view(n_class, 1, 1).expand(n_class, n_query, 1).long())

结果:

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59])


tensor([[[ 0]],

        [[ 1]],

     ......

        [[59]]])


tensor([[[ 0],
         [ 0],
         [ 0],
         [ 0],
         [ 0]],

        [[ 1],
         [ 1],
         [ 1],
         [ 1],
         [ 1]],

      ......

        [[59],
         [59],
         [59],
         [59],
         [59]]])


tensor([[[ 0],
         [ 0],
         [ 0],
         [ 0],
         [ 0]],

        [[ 1],
         [ 1],
         [ 1],
         [ 1],
         [ 1]],

   ......

        [[59],
         [59],
         [59],
         [59],
         [59]]])

你可能感兴趣的:(pytorch)