torch.shape()三维矩阵降为二维

import torch
if __name__ == '__main__':
    a = torch.tensor([ [ [1,2,3,2,1],
                         [2,5,6,3,6] ],
                       [ [2,1,5,9,8],
                         [4,6,8,1,1] ] ],dtype=torch.float)

    a = torch.reshape(a,(-1,5))
    print(a)
tensor([[1., 2., 3., 2., 1.],
        [2., 5., 6., 3., 6.],
        [2., 1., 5., 9., 8.],
        [4., 6., 8., 1., 1.]])

你可能感兴趣的:(pytorch)