pytorch中常见的维度操作

1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展

import torch

'''
维度变换
1、view ;reshape;Flatten:维度合并和分解
2、squeeze;unsqueeze:压缩维度和增加维度(相对于维度为1的数据)
3、transpose;t;permute:维度顺顺序变换(转置)
4、expand;repeat:维度扩展
'''
a = torch.rand(4, 1, 32, 32)

'''
view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。
'''


def zqb_view():
    print(a.shape)  # torch.Size([4, 1, 32, 32])
    a1 = a.view(4, 32 * 32)
    print(a1.shape)  # torch.Size([4, 1024])
    a2 = a1.view(4, 1, 32, 32)
    print(a2.shape)  # torch.Size([4, 1, 32, 32])

    # a3 = a1.view(4,28,28) #RuntimeError: shape '[4, 28, 28]' is invalid for input of size 4096
    #     要保持输出数据与输入数据总量,防止数据污染
    # a4 = a1.view(4,32,32,1) # 逻辑错误,改变了原来数据的存储方式,虽然不会报错,但是数据已经被污染,无法正常使用

    a5 = a.view(-1, 32 * 32)  # torch.Size([4, 1024])  -1表示该维度保持不变
    print

你可能感兴趣的:(深度学习算法与模型,pytorch,深度学习,人工智能)