nn.Module
是PyTorch
提供的神经网络类,该类中实现了网络各层的定义及前向计算与反向传播机制。在实际使用时,如果想要实现某种神经网络模型,只需自定义模型类的时候继承nn.Module
,然后:
__init__()
中定义模型结构与参数;forward()
中编写网络前向过程即可。nn.Module
可以自动利用Autograd
机制实现反向传播,不需要自己手动实现。在Module
的搭建时,可以嵌套包含子Module
,这样的代码分布可以使网络更加模块化,从而提升代码的复用性。在实际的应用中,PyTorch
提供了绝大多数的网络层,如全连接、卷积网络中的卷积、池化等,并自动实现前向与反向传播。
实现一个自定义层,该层将卷积、BN以及激活函数整合成一个层:
class ConvBNReLU(nn.Module):
# 在构造函数中定义层内的结构
def __init__(self, in_channels, out_channels, ks=3, stride=1, padding=1,dilation=1, groups=1, bias=False):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ks, stride=stride,
padding=padding, dilation=dilation,groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
# 在forward方法中定义正向传播步骤
def forward(self, x):
feat = self.conv(x)
feat = self.bn(feat)
feat = self.relu(feat)
return feat
在PyTorch
中有一个子库为nn.functional
,同样也提供了很多网络层与函数功能,但与nn.Module
不同的是,利用nn.functional
定义的网络层不可自动学习参数,还需要使用nn.Parameter封装。
nn.functional
的设计初衷是对于一些不需要学习参数的层,如激活层。总体来看,对于需要学习参数的层,最好使用nn.Module
,对于无参数学习的层,可以使用nn.functional
,当然这两者间并没有严格的好坏之分。
当模型中只是简单的前馈网络时,即上一层的输出直接作为下一层的输入,这时可以采用nn.Sequential()
模块来快速搭建模型,而不必手动在forward()
函数中一层一层地前向传播。因此,如果想快速搭建模型而不考虑中间过程的话,推荐使用nn.Sequential()
模块。
其实模型类和层类本质是一样的,只是在正向传播过程中最后的输出格式可能会不一样,例如:
ConvBNRelu
层,其输出还是特征图的BCHW
格式的Tensor
;1 x num_classes
的Tensor
,每一个元素是对应所属类别的概率或者得分;B num_classes H W
的Tensor
,其宽高与输入大小相同,不过原图每个像素点对应位置的输出是num_classes
维的,表示这个像素属于每一个类的概率或者得分。实现分类网络LeNet
:
# LeNet model
class LeNet(nn.Module):
def __init__(self,num_classes:int=10,init_weights:bool=False):
super(LeNet,self).__init__()
self.features = nn.Sequential(
# conv1
nn.Conv2d(1,6,3),
nn.ReLU(inplace=True),
nn.MaxPool2d(2,2),
# conv2
nn.Conv2d(6,16,3),
nn.ReLU(inplace=True),
nn.MaxPool2d(2,2)
)
self.classifier = nn.Sequential(
nn.Linear(16*6*6,120),
nn.ReLU(inplace=True),
nn.Linear(120,84),
nn.ReLU(inplace=True),
nn.Linear(84,10),
)
def forward(self,x):
B,C,H,W = x.shape
assert H == 32 and W == 32,f"Input image size should be 32x32."
x = self.features(x)
x = torch.flatten(x,1)
x = self.classifier(x)
return x