PyTorch之Module

torch.nn的核心数据结构是Module,它是一个抽象概念,既可以表示神经网络中的某个层(layer),也可以表示一个包含很多层的神经网络。

定义网络

定义网络时,需要继承nn.Module,并实现它的forward方法,把网络中具有可学习参数的层放在构造函数__init__中。如某一层不具有可学习参数(如relu),则既可以放在构造函数中,也可以不放

import torch.nn as nn # nn是专门为神经网络设计的模块化接口,可用来定义和运行神经网络
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        # nn.Module的子类函数必须在构造函数中执行父类的构造函数
        super(LeNet, self).__init__()   # 等价与nn.Module.__init__()
 
        # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
        # 当调用self.conv1(input)的时候,就会调用该类的forward函数
        #卷积层1表示输入图片为单通道,6表示输出通道数,(5,5)表示卷积核为5*5
        self.conv1 = nn.Conv2d(1, 6, (5, 5))   # output (N, C_{out}, H_{out}, W_{out})`
        self.conv2 = nn.Conv2d(6, 16, (5, 5))
        # 全连接层,y = Wx+b
        self.fc1 = nn.Linear(256, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
    # 卷积→激活→池化
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  # F.max_pool2d的返回值是一个Variable
        x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
        x = x.view(x.size()[0], -1)     # -1表示自适应
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
 
        return x

net = LeNet()
print(net)
Out:

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

只要在nn.Module的子类中定义了forward函数,backward函数就会被自动实现(利用Autograd)。
网络的可学习参数通过net.parameters()返回。可学习参数和名称可以通过net.named——parameters同时返回。

params = list(net.parameters())
print(len(params))

for name,parameters in net.named_parameters():
    print(name,":",parameters.size())
Out:

10
conv1.weight : torch.Size([6, 1, 5, 5])
conv1.bias : torch.Size([6])
conv2.weight : torch.Size([16, 6, 5, 5])
conv2.bias : torch.Size([16])
fc1.weight : torch.Size([120, 256])
fc1.bias : torch.Size([120])
fc2.weight : torch.Size([84, 120])
fc2.bias : torch.Size([84])
fc3.weight : torch.Size([10, 84])
fc3.bias : torch.Size([10])

你可能感兴趣的:(python)