pytorch中view函数的常用写法解释

下图是一个简单的神经网络(不是CNN),他是用来拟合一个抛物线的,但是本文的重点是forward函数中的x.view(-1, n)函数,很多时候你会发现许多的网络都是这么写的,2个参数值里有一个-1,另外一个是正数,view函数是不改变数据的情况下任意修改Tensor的形状(你不妨认为你的Tensor是106块麻将结合到一起,你可把它摆放成一个长条的,也可以堆成堆,总之还是那些牌,还是106块)。
-1代表的意思是不确定,主要依据那个非-1的参数,在本例中,就是依靠那个n,就是只要保证列数是n即可,行数随意;如果是x.view(n, -1),那就是说我只要求行数满足n行,列数你自己安排就好。比如说Tensor有600行,n为100,那么为-1那个行数或者就是600÷100=6,是程序自己解决的。不过这里的n取值要能够整除Tensor的大小,比如你这里让n=599,就会报错,因为600÷599没法得到整数。不过大部分的例子中,执行完都是得到一个一维的Tensor(行数为1或者列数为1,好像pytorch中不太区分到底是一行n列还是n行一列,你不妨理解为“一长条”),一般用在卷积神经网络的全连接层比较多。
resize,view和reshape用法功能基本相同。

我的解释可能穿在错误或者理解不够深刻的地方,欢迎大家指正。

class MyCNN(torch.nn.Module):
    def __init__(self, in_cheng, hid_cheng, out_cheng):
        super().__init__()
        self.incheng = torch.nn.Sequential(
            torch.nn.Linear(in_cheng, hid_cheng),
            torch.nn.ReLU(),
            torch.nn.Linear(hid_cheng, out_cheng),
            torch.nn.ReLU()
        )

    def forward(self, x):
        x = self.incheng(x)
        x.view(-1, n)
        return x

你可能感兴趣的:(pytorch,神经网络)