Pytorch view函数讲解

view()函数作用和reshape函数类似,就是对tensor的shape进行调整,可以通过view函数将tensor的shape调整成一个你希望的样子。

import torch

torch.manual_seed(2)
a=torch.randn(4,5)
print(a)

print(a.view(-1,2)) # 此时 -1,代表默认值,代表根据后面的列数来计算行数


'''
a
=
tensor([[-1.0408,  0.9166, -1.3042, -1.1097,  0.0299],
        [-0.0498,  1.0651,  0.8860, -0.8110,  0.6737],
        [-1.1233, -0.0919,  0.1405,  1.1191,  0.3152],
        [ 1.7528, -0.7396, -1.2425, -0.1752,  0.6990]])


a.view(-1,2)
=
tensor([[-1.0408,  0.9166],
        [-1.3042, -1.1097],
        [ 0.0299, -0.0498],
        [ 1.0651,  0.8860],
        [-0.8110,  0.6737],
        [-1.1233, -0.0919],
        [ 0.1405,  1.1191],
        [ 0.3152,  1.7528],
        [-0.7396, -1.2425],
        [-0.1752,  0.6990]])
'''

此时通过view调整,将一个 45,一共20个数字,变成了 102 (10为 view中 -1的默认数,代表按照后面的列数来计算的行数)

在训练神经网络时,经常会遇到这样的一段代码

x = x.view(x.size(0), -1)

x.size(0)指batchsize的值。这句话的出现就是为了将前面多维度的tensor展平成一维,然后再输入给 nn.Linear()结构,-1指在不告诉函数有多少列的情况下,根据原tensor数据和batchsize自动分配列数。

你可能感兴趣的:(Deep,Learning,python,深度学习)