代码是一个使用PyTorch实现的GoogLeNet模型,该模型是一个深度卷积神经网络(CNN)用于图像分类任务。
1. 定义基本卷积模块
BasicConv2d
类是一个基本的卷积块,包含一个卷积层、批归一化层和ReLU激活函数。
该类用于构建Inception模块中的各分支。
# 导入torch库,torch是一个基于Python的科学计算框架,主要用于深度学习
import torch
# 导入torch.nn库,torch.nn是一个神经网络模块,提供了各种层和激活函数等
import torch.nn as nn
# 导入torchinfo库,torchinfo是一个用于打印模型结构和参数信息的库
from torchinfo import summary
# 定义一个基本的卷积层类,继承自nn.Module
class BasicConv2d(nn.Module):
# 定义初始化方法,接收输入通道数,输出通道数,以及其他可变参数
def __init__(self, inplanes, out_channels, **kwargs):
# 调用父类的初始化方法
super(BasicConv2d, self).__init__()
# 定义一个卷积模块,包含三个子层:卷积层,批归一化层,和ReLU激活层
self.conv = nn.Sequential(
# 定义一个卷积层,接收输入通道数,输出通道数,以及其他可变参数,不使用偏置项
nn.Conv2d(inplanes, out_channels, bias=False, **kwargs),
# 定义一个批归一化层,接收输出通道数
nn.BatchNorm2d(out_channels),
# 定义一个ReLU激活层,使用原地操作
nn.ReLU(inplace=True)
)
# 定义前向传播方法,接收输入张量x
def forward(self, x):
# 将x通过卷积模块,得到输出张量
x = self.conv(x)
# 返回输出张量
return x
2. 定义Inception层
Inception
类定义了Inception模块,包含四个分支:1x1卷积、1x1卷积后接3x3卷积、1x1卷积后接5x5卷积、最大池化后接1x1卷积。
每个分支使用 BasicConv2d
构建。
模块的输出是四个分支的拼接。
# 定义一个Inception模块类,继承自nn.Module
class Inception(nn.Module):
# 定义初始化方法,接收输入通道数,以及各个分支的输出通道数
def __init__(self, inplanes, ch1x1, ch3x3reduce, ch3x3, ch5x5reduce, ch5x5, pool_proj):
# 调用父类的初始化方法
super(Inception, self).__init__()
# 定义第一个分支,使用一个1x1的卷积层
self.branch1 = BasicConv2d(inplanes, ch1x1, kernel_size=1)
# 定义第二个分支,使用一个1x1的卷积层,后接一个3x3的卷积层,使用1的填充
self.branch2 = nn.Sequential(
BasicConv2d(inplanes, ch3x3reduce, kernel_size=1),
BasicConv2d(ch3x3reduce, ch3x3, kernel_size=3, padding=1)
)
# 定义第三个分支,使用一个1x1的卷积层,后接一个5x5的卷积层,使用2的填充
self.branch3 = nn.Sequential(
BasicConv2d(inplanes, ch5x5reduce, kernel_size=1),
BasicConv2d(ch5x5reduce, ch5x5, kernel_size=5, padding=2)
)
# 定义第四个分支,使用一个3x3的最大池化层,后接一个1x1的卷积层,使用1的填充
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
BasicConv2d(inplanes, pool_proj, kernel_size=1)
)
# 定义前向传播方法,接收输入张量x
def forward(self, x):
# 将x分别通过四个分支,得到四个输出张量
branchh1 = self.branch1(x)
branchh2 = self.branch2(x)
branchh3 = self.branch3(x)
branchh4 = self.branch4(x)
# 将四个输出张量放入一个列表中
output = [branchh1, branchh2, branchh3, branchh4]
# 沿着通道维度,将列表中的张量拼接起来,得到最终的输出张量
return torch.cat(output, 1)
3. 定义一个辅助分类器类
AuxClf
类定义了辅助分类器,用于提供额外的梯度信息。
包含特征提取部分和分类器部分。
特征提取部分包含平均池化和1x1卷积。
分类器部分包含两个全连接层。
# 定义一个辅助分类器类,继承自nn.Module
class AuxClf(nn.Module):
# 定义初始化方法,接收输入通道数,分类数,以及其他可变参数
def __init__(self, inplanes, num_classes, **kwargs):
# 调用父类的初始化方法
super(AuxClf, self).__init__()
# 定义一个特征提取模块,包含两个子层:平均池化层,和卷积层
self.feature_ = nn.Sequential(
# 定义一个平均池化层,使用5x5的核,3的步长
nn.AvgPool2d(kernel_size=5, stride=3),
# 定义一个卷积层,使用1x1的核,输出通道数为128
BasicConv2d(inplanes, 128, kernel_size=1)
)
# 定义一个分类模块,包含四个子层:线性层,ReLU激活层,Dropout层,和线性层
self.clf_ = nn.Sequential(
# 定义一个线性层,输入维度为4*4*128,输出维度为1024
nn.Linear(4*4*128, 1024),
# 定义一个ReLU激活层,使用原地操作
nn.ReLU(inplace=True),
# 定义一个Dropout层,使用0.7的丢弃率
nn.Dropout(0.7),
# 定义一个线性层,输入维度为1024,输出维度为分类数
nn.Linear(1024, num_classes)
)
# 定义前向传播方法,接收输入张量x
def forward(self, x):
# 将x通过特征提取模块,得到输出张量
x = self.feature_(x)
# 将输出张量展平为一维向量,维度为4*4*128
x = x.view(-1, 4*4*128)
# 将展平后的向量通过分类模块,得到输出向量
x = self.clf_(x)
# 返回输出向量
return x
4. 定义一个GoogLeNet类
GoogLeNet
类是整个GoogLeNet模型的定义。
包含多个阶段,每个阶段包含不同的卷积和Inception模块。
有全局平均池化、Dropout和全连接层组成的分类器。
有两个辅助分类器 (AuxClf
) 提供额外的梯度信息。
在 forward
方法中定义了模型的前向传播过程。
# 定义一个GoogLeNet类,继承自nn.Module
class GoogLeNet(nn.Module):
# 定义初始化方法,接受一个参数num_classes,表示分类的类别数,默认为1000
def __init__(self, num_classes: int = 1000):
# 调用父类的初始化方法
super(GoogLeNet, self).__init__()
# 定义第一个卷积层,输入通道为3,输出通道为64,卷积核大小为7,步长为2,边缘填充为3
self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) # 输出112x112x64
# 定义第一个最大池化层,池化核大小为3,步长为2,使用向上取整的方式处理边缘
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) # (112 - 3) / 2 + 1 = 55.5,向上取整,输出56x56x64
# 定义第二个卷积层,输入通道为64,输出通道为64,卷积核大小为1,步长为1
self.conv2 = BasicConv2d(64, 64, kernel_size=1, stride=1)
# 定义第三个卷积层,输入通道为64,输出通道为192,卷积核大小为3,步长为1,边缘填充为1
self.conv3 = BasicConv2d(64, 192, kernel_size=3, stride=1, padding=1) # 输出56x56x192
# 定义第二个最大池化层,池化核大小为3,步长为2,使用向上取整的方式处理边缘
self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) # 输出28x28x192
# 定义第一个Inception模块,输入通道为192,输出通道为64,96,128,16,32,32
self.inception3a = Inception(192,64,96, 128, 16, 32, 32) # 输出28x28x256
# 定义第二个Inception模块,输入通道为256,输出通道为128,128,192,32,96,64
self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) # 输出28x28x480
# 定义第三个最大池化层,池化核大小为3,步长为2,使用向上取整的方式处理边缘
self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) # 输出14x14x480
# 定义第三个Inception模块,输入通道为480,输出通道为192,96,208,16,48,64
self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) # 输出14x14x512
# 定义第四个Inception模块,输入通道为512,输出通道为160,112,224,24,64,64
self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) # 输出14x14x512
# 定义第五个Inception模块,输入通道为512,输出通道为128,128,256,24,64,64
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) # 输出14x14x512
# 定义第六个Inception模块,输入通道为512,输出通道为112,144,288,32,64,64
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) # 输出14x14x528
# 定义第七个Inception模块,输入通道为528,输出通道为256,160,320,32,128,128
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) # 输出14x14x832
# 定义第四个最大池化层,池化核大小为3,步长为2,使用向上取整的方式处理边缘
self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) # 输出7x7x832
# 定义第八个Inception模块,输入通道为832,输出通道为256,160,320,32,128,128
self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) # 输出7x7x832
# 定义第九个Inception模块,输入通道为832,输出通道为384,192,384,48,128,128
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) # 输出7x7x1024
# 定义自适应平均池化层,输出大小为1x1
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # 输出1x1x1024
# 定义一个0.4的随机失活层
self.dropout = nn.Dropout(0.4)
# 定义一个全连接层,输入维度为1024,输出维度为num_classes
self.fc = nn.Linear(1024, num_classes)
# 定义第一个辅助分类器,输入通道为512,输出通道为num_classes
self.aux1 = AuxClf(512, num_classes)
# 定义第二个辅助分类器,输入通道为528,输出通道为num_classes
self.aux2 = AuxClf(528, num_classes)
# 定义前向传播方法,接受一个参数x,表示输入的数据
def forward(self, x): # 前向传播函数
x = self.maxpool1(self.conv1(x)) # 第一层卷积和池化
x = self.maxpool2(self.conv3(self.conv2(x))) # 第二层和第三层卷积,然后池化
x = self.maxpool3(self.inception3b(self.inception3a(x))) # 两个Inception模块,然后池化
x = self.inception4a(x) # 第三个Inception模块
aux1 = self.aux1(x) # 第一个辅助分类器的输出
x = self.inception4b(x) # 第四个Inception模块
x = self.inception4c(x) # 第五个Inception模块
x = self.inception4d(x) # 第六个Inception模块
aux2 = self.aux2(x) # 第二个辅助分类器的输出
x = self.maxpool4(self.inception4e(x)) # 第七个Inception模块,然后池化
x = self.inception5b(self.inception5a(x)) # 最后两个Inception模块
x = self.avgpool(x) # 平均池化
x = torch.flatten(x, 1) # 展平
x = self.dropout(x) # Dropout
x = self.fc(x) # 全连接层
return x, aux1, aux2 # 返回主分类器和两个辅助分类器的输出
5. 主函数
主函数创建了一个模拟输入数据,并实例化了 GoogLeNet
模型。
调用模型进行前向传播,并输出主分类器以及两个辅助分类器的输出形状。
使用 torchinfo.summary
打印模型的详细信息。
# 如果是主程序
if __name__ == '__main__':
# 创建一个全1的张量,形状为20x3x224x224,表示有20个样本,每个样本有3个通道,高和宽都是224
data = torch.ones(20, 3, 224, 224)
# 创建一个GoogLeNet实例
net = GoogLeNet()
# 将数据输入到模型中,得到模型的输出和两个辅助分类器的输出
x, aux1, aux2 = net(data)
# 遍历三个输出,打印它们的形状
for i in [x, aux1, aux2]:
print(i.shape)
# 使用torchinfo库的summary函数,打印模型的参数统计,输入形状为20x3x224x224,设备为cpu
summary(net, (20, 3, 224, 224), device='cpu')
总体而言,该代码实现了一个完整的GoogLeNet模型,用于图像分类任务,并包括了辅助分类器以提高训练效果。
输出结果:模型的输出和两个辅助分类器的输出、模型的参数统计
torch.Size([20, 1000])
torch.Size([20, 1000])
torch.Size([20, 1000])
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
GoogLeNet [20, 1000] --
├─BasicConv2d: 1-1 [20, 64, 112, 112] --
│ └─Sequential: 2-1 [20, 64, 112, 112] --
│ │ └─Conv2d: 3-1 [20, 64, 112, 112] 9,408
│ │ └─BatchNorm2d: 3-2 [20, 64, 112, 112] 128
│ │ └─ReLU: 3-3 [20, 64, 112, 112] --
├─MaxPool2d: 1-2 [20, 64, 56, 56] --
├─BasicConv2d: 1-3 [20, 64, 56, 56] --
│ └─Sequential: 2-2 [20, 64, 56, 56] --
│ │ └─Conv2d: 3-4 [20, 64, 56, 56] 4,096
│ │ └─BatchNorm2d: 3-5 [20, 64, 56, 56] 128
│ │ └─ReLU: 3-6 [20, 64, 56, 56] --
├─BasicConv2d: 1-4 [20, 192, 56, 56] --
│ └─Sequential: 2-3 [20, 192, 56, 56] --
│ │ └─Conv2d: 3-7 [20, 192, 56, 56] 110,592
│ │ └─BatchNorm2d: 3-8 [20, 192, 56, 56] 384
│ │ └─ReLU: 3-9 [20, 192, 56, 56] --
├─MaxPool2d: 1-5 [20, 192, 28, 28] --
├─Inception: 1-6 [20, 256, 28, 28] --
│ └─BasicConv2d: 2-4 [20, 64, 28, 28] --
│ │ └─Sequential: 3-10 [20, 64, 28, 28] 12,416
│ └─Sequential: 2-5 [20, 128, 28, 28] --
│ │ └─BasicConv2d: 3-11 [20, 96, 28, 28] 18,624
│ │ └─BasicConv2d: 3-12 [20, 128, 28, 28] 110,848
│ └─Sequential: 2-6 [20, 32, 28, 28] --
│ │ └─BasicConv2d: 3-13 [20, 16, 28, 28] 3,104
│ │ └─BasicConv2d: 3-14 [20, 32, 28, 28] 12,864
│ └─Sequential: 2-7 [20, 32, 28, 28] --
│ │ └─MaxPool2d: 3-15 [20, 192, 28, 28] --
│ │ └─BasicConv2d: 3-16 [20, 32, 28, 28] 6,208
├─Inception: 1-7 [20, 480, 28, 28] --
│ └─BasicConv2d: 2-8 [20, 128, 28, 28] --
│ │ └─Sequential: 3-17 [20, 128, 28, 28] 33,024
│ └─Sequential: 2-9 [20, 192, 28, 28] --
│ │ └─BasicConv2d: 3-18 [20, 128, 28, 28] 33,024
│ │ └─BasicConv2d: 3-19 [20, 192, 28, 28] 221,568
│ └─Sequential: 2-10 [20, 96, 28, 28] --
│ │ └─BasicConv2d: 3-20 [20, 32, 28, 28] 8,256
│ │ └─BasicConv2d: 3-21 [20, 96, 28, 28] 76,992
│ └─Sequential: 2-11 [20, 64, 28, 28] --
│ │ └─MaxPool2d: 3-22 [20, 256, 28, 28] --
│ │ └─BasicConv2d: 3-23 [20, 64, 28, 28] 16,512
├─MaxPool2d: 1-8 [20, 480, 14, 14] --
├─Inception: 1-9 [20, 512, 14, 14] --
│ └─BasicConv2d: 2-12 [20, 192, 14, 14] --
│ │ └─Sequential: 3-24 [20, 192, 14, 14] 92,544
│ └─Sequential: 2-13 [20, 208, 14, 14] --
│ │ └─BasicConv2d: 3-25 [20, 96, 14, 14] 46,272
│ │ └─BasicConv2d: 3-26 [20, 208, 14, 14] 180,128
│ └─Sequential: 2-14 [20, 48, 14, 14] --
│ │ └─BasicConv2d: 3-27 [20, 16, 14, 14] 7,712
│ │ └─BasicConv2d: 3-28 [20, 48, 14, 14] 19,296
│ └─Sequential: 2-15 [20, 64, 14, 14] --
│ │ └─MaxPool2d: 3-29 [20, 480, 14, 14] --
│ │ └─BasicConv2d: 3-30 [20, 64, 14, 14] 30,848
├─AuxClf: 1-10 [20, 1000] --
│ └─Sequential: 2-16 [20, 128, 4, 4] --
│ │ └─AvgPool2d: 3-31 [20, 512, 4, 4] --
│ │ └─BasicConv2d: 3-32 [20, 128, 4, 4] 65,792
│ └─Sequential: 2-17 [20, 1000] --
│ │ └─Linear: 3-33 [20, 1024] 2,098,176
│ │ └─ReLU: 3-34 [20, 1024] --
│ │ └─Dropout: 3-35 [20, 1024] --
│ │ └─Linear: 3-36 [20, 1000] 1,025,000
├─Inception: 1-11 [20, 512, 14, 14] --
│ └─BasicConv2d: 2-18 [20, 160, 14, 14] --
│ │ └─Sequential: 3-37 [20, 160, 14, 14] 82,240
│ └─Sequential: 2-19 [20, 224, 14, 14] --
│ │ └─BasicConv2d: 3-38 [20, 112, 14, 14] 57,568
│ │ └─BasicConv2d: 3-39 [20, 224, 14, 14] 226,240
│ └─Sequential: 2-20 [20, 64, 14, 14] --
│ │ └─BasicConv2d: 3-40 [20, 24, 14, 14] 12,336
│ │ └─BasicConv2d: 3-41 [20, 64, 14, 14] 38,528
│ └─Sequential: 2-21 [20, 64, 14, 14] --
│ │ └─MaxPool2d: 3-42 [20, 512, 14, 14] --
│ │ └─BasicConv2d: 3-43 [20, 64, 14, 14] 32,896
├─Inception: 1-12 [20, 512, 14, 14] --
│ └─BasicConv2d: 2-22 [20, 128, 14, 14] --
│ │ └─Sequential: 3-44 [20, 128, 14, 14] 65,792
│ └─Sequential: 2-23 [20, 256, 14, 14] --
│ │ └─BasicConv2d: 3-45 [20, 128, 14, 14] 65,792
│ │ └─BasicConv2d: 3-46 [20, 256, 14, 14] 295,424
│ └─Sequential: 2-24 [20, 64, 14, 14] --
│ │ └─BasicConv2d: 3-47 [20, 24, 14, 14] 12,336
│ │ └─BasicConv2d: 3-48 [20, 64, 14, 14] 38,528
│ └─Sequential: 2-25 [20, 64, 14, 14] --
│ │ └─MaxPool2d: 3-49 [20, 512, 14, 14] --
│ │ └─BasicConv2d: 3-50 [20, 64, 14, 14] 32,896
├─Inception: 1-13 [20, 528, 14, 14] --
│ └─BasicConv2d: 2-26 [20, 112, 14, 14] --
│ │ └─Sequential: 3-51 [20, 112, 14, 14] 57,568
│ └─Sequential: 2-27 [20, 288, 14, 14] --
│ │ └─BasicConv2d: 3-52 [20, 144, 14, 14] 74,016
│ │ └─BasicConv2d: 3-53 [20, 288, 14, 14] 373,824
│ └─Sequential: 2-28 [20, 64, 14, 14] --
│ │ └─BasicConv2d: 3-54 [20, 32, 14, 14] 16,448
│ │ └─BasicConv2d: 3-55 [20, 64, 14, 14] 51,328
│ └─Sequential: 2-29 [20, 64, 14, 14] --
│ │ └─MaxPool2d: 3-56 [20, 512, 14, 14] --
│ │ └─BasicConv2d: 3-57 [20, 64, 14, 14] 32,896
├─AuxClf: 1-14 [20, 1000] --
│ └─Sequential: 2-30 [20, 128, 4, 4] --
│ │ └─AvgPool2d: 3-58 [20, 528, 4, 4] --
│ │ └─BasicConv2d: 3-59 [20, 128, 4, 4] 67,840
│ └─Sequential: 2-31 [20, 1000] --
│ │ └─Linear: 3-60 [20, 1024] 2,098,176
│ │ └─ReLU: 3-61 [20, 1024] --
│ │ └─Dropout: 3-62 [20, 1024] --
│ │ └─Linear: 3-63 [20, 1000] 1,025,000
├─Inception: 1-15 [20, 832, 14, 14] --
│ └─BasicConv2d: 2-32 [20, 256, 14, 14] --
│ │ └─Sequential: 3-64 [20, 256, 14, 14] 135,680
│ └─Sequential: 2-33 [20, 320, 14, 14] --
│ │ └─BasicConv2d: 3-65 [20, 160, 14, 14] 84,800
│ │ └─BasicConv2d: 3-66 [20, 320, 14, 14] 461,440
│ └─Sequential: 2-34 [20, 128, 14, 14] --
│ └─BasicConv2d: 2-36 [20, 256, 7, 7] --
│ │ └─Sequential: 3-71 [20, 256, 7, 7] 213,504
│ └─Sequential: 2-37 [20, 320, 7, 7] --
│ │ └─BasicConv2d: 3-72 [20, 160, 7, 7] 133,440
│ │ └─BasicConv2d: 3-73 [20, 320, 7, 7] 461,440
│ └─Sequential: 2-38 [20, 128, 7, 7] --
│ │ └─BasicConv2d: 3-74 [20, 32, 7, 7] 26,688
│ │ └─BasicConv2d: 3-75 [20, 128, 7, 7] 102,656
│ └─Sequential: 2-39 [20, 128, 7, 7] --
│ │ └─MaxPool2d: 3-76 [20, 832, 7, 7] --
│ │ └─BasicConv2d: 3-77 [20, 128, 7, 7] 106,752
├─Inception: 1-18 [20, 1024, 7, 7] --
│ └─BasicConv2d: 2-40 [20, 384, 7, 7] --
│ │ └─Sequential: 3-78 [20, 384, 7, 7] 320,256
│ └─Sequential: 2-41 [20, 384, 7, 7] --
│ │ └─BasicConv2d: 3-79 [20, 192, 7, 7] 160,128
│ │ └─BasicConv2d: 3-80 [20, 384, 7, 7] 664,320
│ └─Sequential: 2-42 [20, 128, 7, 7] --
│ │ └─BasicConv2d: 3-81 [20, 48, 7, 7] 40,032
│ │ └─BasicConv2d: 3-82 [20, 128, 7, 7] 153,856
│ └─Sequential: 2-43 [20, 128, 7, 7] --
│ │ └─MaxPool2d: 3-83 [20, 832, 7, 7] --
│ │ └─BasicConv2d: 3-84 [20, 128, 7, 7] 106,752
├─AdaptiveAvgPool2d: 1-19 [20, 1024, 1, 1] --
├─Dropout: 1-20 [20, 1024] --
├─Linear: 1-21 [20, 1000] 1,025,000
===============================================================================================
Total params: 13,385,816
Trainable params: 13,385,816
Non-trainable params: 0
Total mult-adds (G): 31.82
===============================================================================================Input size (MB): 12.04
Forward/backward pass size (MB): 1034.49
Params size (MB): 53.54
Estimated Total Size (MB): 1100.08
===============================================================================================
GoogLeNet 网络中LocalRespNorm
GoogLeNet网络中的(Batch Normalization)批量归一化
BN(批量归一化)和LRN(局部归一化)有什么区别?
MaxPool2d类
定义class GoogLeNet(nn.Module)类时,需要知道上一层的输出大小,才能初始化下一层的参数
解释 summary(net, (20, 3, 224, 224), device='cpu')
如何计算conv1输出层的特征宽高
如何计算池化层的输出特征图大小?
参考网址:
https://wxler.github.io/2020/11/24/223407/ 深入解读GoogLeNet网络结构(附代码实现) | Layne's Blog (wxler.github.io)