如何定义PyTorch模型

作者:geekboys
日期:2020-3-4

PyTorch模型定义的三要素

1.必须继承nn.Module这个类,要让PyTorch知道这个类是一个Module
2.在init(self)中设置好需要的"组件"(如conv,pooling,Linear,BatchNorm等)
3.最后,在forward(self,x)中定义好的“组件”进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了。

这里可以搭建一个简单的模型来体现一下这种模型搭建的方法:

#一个简单的模型
import torch
import torch.nn as nn
import torch.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()#实现父类的初始化
        self.conv1=nn.Conv2d(3,6,5)#定义卷积层组件
        self.pool1=nn.MaxPool2d(2,2)#定义池化层组件
        self.conv2=nn.Conv2dn(6,16,5)
        self.pool2=nn.MaxPool2d(2,2)
        self.fc1=nn.Linear(16*5*5,120)#定义线性连接
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)

当这些组件定义好之后,就可以定义forward()函数,用来搭建模型结构。

def forward(self,x):#x模型的输入
    x=self.pool1(F.relu(self.conv1(x)))
    x=self.pool2(F.relu(self.conv2(x)))
    x=x.view(-1,16*5*5)#表示将x进行reshape,为后面做为全连接层的输入
    x=F.relu(self.fc1(x))
    x=F.relu(self.fc2(x))
    x=self.fc3(x)
    return x

上面我们就成功的搭建了一个网络是不是很方便,当我们实例化一个模型net=Net(),然后把输入inputs扔进去,outputs=net(inputs)就可以得到输出outputs.
在PyTorch模型定义中还会经常的使用Sequetial这个组件

nn.Sequetial

torch.nn.Sequential其实就是Sequential容器,该容器将一系列操作按先后顺序给包起来,方便重复使用。
所以总结起来,PyTorch模型的定义过程为:

模型的定义就是先继承,在构建组件,最后组装

你可能感兴趣的:(如何定义PyTorch模型)