Python_2019-10-07_机器视觉——Pytorch——基础网络构建

 一承二构三组建

init的
    super的init 
forward
    view

构建参数输出   

Python_2019-10-07_机器视觉——Pytorch——基础网络构建_第1张图片

import torch.nn as nn #
import torch.nn.functional as F #

# no_1:nn.Module           继承Module类
# no_2: __init__(self)      设置组件
# no_3: forward(self,x)     搭建前向网络
# 综合:1继承->2构建->3组建(一承二构三组建)
class ournet(nn.Module):
    def __init__(self):
        # nn.Module子类的函数必须在构造函数中执行父类的构造函数
        # 等价于nn.Model.__init__(self)
        super(ournet,self).__init__() 
       
        # 输入1通道,输出6通道,卷积核5*5
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 定义卷积层:输入6张特征图,输出16张特征图,卷积核5x5
        self.conv2 = nn.Conv2d(6,16,5)

        # 定义全连接层:线性连接(y = Wx + b),16*5*5个节点连接到120个节点上
        # 计算: input=32*32*1 — Conv2d=28*28*6 — max_pool2d=14*14*6 — Conv2d=10*10*16 — max_pool2d=5*5*16
        self.fc1 = nn.Linear(16*5*5,120)
        
        # 定义全连接层:线性连接(y = Wx + b),120个节点连接到84个节点上
        self.fc2 = nn.Linear(120,84)
        # 定义全连接层:线性连接(y = Wx + b),84个节点连接到10个节点上
        self.fc3 = nn.Linear(84,10)

    def forward(self, x):
        # 输入x->conv1->relu->2x2窗口的最大池化->更新到x
        x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        # 输入x->conv2->relu->2x2窗口的最大池化->更新到x
        x = F.max_pool2d(F.relu(self.conv2(x)),2)
        
        # view函数将张量x变形成一维向量形式,总特征数不变,为全连接层做准备
        x = x.view(x.size()[0], -1)
        
        # 输入x->fc1->relu,更新到x
        x = F.relu(self.fc1(x))
        # 输入x->fc2->relu,更新到x
        x = F.relu(self.fc2(x))
        # 输入x->fc3,更新到x
        x = self.fc3(x)
        return x


if __name__ == "__main__":
    net = ournet()
    print("网络########################################")
    print(net)
    params = list(net.parameters())
    #print(params)
    print("len_params########################################")
    print(len(params))
    print("params########################################")
    for name, parameters in net.named_parameters():
        print(name, ":", parameters.size())

        '''
        torch.nn:网络层
        torch.nn.functional:激活函数、池化函数归于此模块
        网络主体:
        net网络要使用class并继承父类才行,因而有一些自带的方法
        net.parameters():返回全部的参数值,迭代器
        net.named_parameters():返回参数名称和值,迭代器
        net.参数名:就是参数变量,Variable,可以直接查看data和grad等等
        '''

 

你可能感兴趣的:(Python_2019-10-07_机器视觉——Pytorch——基础网络构建)