pytorch中nn.functional()学习总结

nn.functional是一个很常用的模块,nn中的大多数layer在functional中都有一个与之对应的函数。nn.functional中的函数与nn.Module()的区别是:

  1. nn.Module实现的层(layer)是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数
  2. nn.functional中的函数更像是纯函数,由def functional(input)定义

如:
 

# -*- coding: utf-8 -*-
#@Time    :2019/7/4 22:15
#@Author  :XiaoMa
import torch as t
import torch.nn as nn
from torch.autograd import Variable as V

input=V(t.randn(2,3))
model=nn.Linear(3,4)
output1=model(input)    #得到输出方式1

output2=nn.functional.linear(input,model.weight,model.bias) #得到输出方式2

print(output1==output2)

b=nn.functional.relu(input)
b2=nn.ReLU(input)

print(b==b2)


输出结果:
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]], dtype=torch.uint8)

注意:

如果模型有可学习的参数时,最好使用nn.Module;否则既可以使用nn.functional也可以使用nn.Module,二者在性能上没有太大差异,具体的使用方式取决于个人喜好。由于激活函数(ReLu、sigmoid、Tanh)、池化(MaxPool)等层没有可学习的参数,可以使用对应的functional函数,而卷积、全连接等有可学习参数的网络建议使用nn.Module。

注意:

虽然dropout没有可学习参数,但建议还是使用nn.Dropout而不是nn.functional.dropout,因为dropout在训练和测试两个阶段的行为有所差别,使用nn.Module对象能够通过model.eval操作加以区分。

示例代码:

from torch.nn import functional as F
class Net(nn.Module):
    def __init(self):
    nn.Module(self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    slf.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)    #16--上一层输出为16通道,两个5为上一层的卷积核的宽和高
                                        #所以这一层的输入大小为:16*5*5
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

    def forward(self,x):    #对父类方法的重载
        x=F.pool(F.relu(self.conv1(x)),2)
        x=F.pool(F.relu(self.conv2(x)),2)
        x=x.view(-1,16*5*5)
        x=x.relu(self.fc1(x))
        x=x.relu(self.fc2(x))
        x=self.fc3(x)

        return x

代码说明:

在代码中,不具备可学习参数的层(激活层、池化层),将它们用函数代替,这样可以不用放置在构造函数__init__中。有可学习的模块,也可以用functional代替,只不过实现起来比较繁琐,需要手动定义参数parameter,如前面实现自定义的全连接层,就可以将weight和bias两个参数单独拿出来,在构造函数中初始化为parameter。

如这种:

class MyLinear(nn.Module):
    def __init__(self):
    super(MyLinear,self).__init__():
    self.weight=nn.Parameter(t.randn(3,4))
    self.bias=nn.Parameter(t.zeros(3))
    
    def forward(self):
        return F.linear(input,weight,bias)

 

 

 

 

 

你可能感兴趣的:(deepLearning,pytorch)