PyTorch(1) torch.nn与torch.nn.functional之间的区别和联系

原文地址

在阅读PyTorch官网的教程的时候,发现介绍如何利用Pytorch搭建一个神经网络的示例代码是这样写。

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

迷惑的地方是在于forward的函数的定义方法。为什么要把网络中的一部分层在__init__()函数里定义出来,而另一部分层则是在__forward()__函数里定义?并且一个用的是nn,另一个用的是nn.functional。同一种层的API定义有两种,这样看似冗余的设计是为了什么呢?

通过阅读源码我们会发现,nn和nn.functional之间的差别如下,我们以conv2d的定义为例。

 torch.nn.Conv2d

import torch.nn.functional as F
class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

torch.nn.functional.conv2d

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
           groups=1):

    if input is not None and input.dim() != 4:
        raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))

    f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
                _pair(0), groups, torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
    return f(input, weight, bias)
 

从中我们可以发现,nn.Conv2d是一个类,而F.conv2d()是一个函数,而nn.Conv2d的forward()函数实现是用F.conv2d()实现的(在Module类里的__call__实现了forward()函数的调用,所以当实例化nn.Conv2d类时,forward()函数也被执行了,详细可阅读torch源码),所以两者功能并无区别,那么为什么要有这样的两种实现方式同时存在呢?

原因其实在于,为了兼顾灵活性和便利性。

在建图过程中,往往有两种层,一种如全连接层,卷积层等,当中有Variable,另一种如Pooling层,Relu层等,当中没有Variable。

如果所有的层都用nn.functional来定义,那么所有的Variable,如weights,bias等,都需要用户来手动定义,非常不方便。

而如果所有的层都换成nn来定义,那么即便是简单的计算都需要建类来做,而这些可以用更为简单的函数来代替的。
 所以在定义网络的时候,如果层内有Variable,那么用nn定义,反之,则用nn.functional定义

你可能感兴趣的:(PyTorch(1) torch.nn与torch.nn.functional之间的区别和联系)