[ Image Classification ] Classic Network Models 2: GoogLeNet Explained and Reproduced

Author: Horizon Max


  • GoogLeNet
  • GoogLeNet Explained
    • Inception module
      • Core idea
    • GoogLeNet Architecture
  • Reproducing GoogLeNet


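GoogLeNet is an entirely new deep learning architecture proposed by Christian Szegedy et al. in 2014.

Earlier architectures such as AlexNet and VGG achieved better results mainly by increasing network depth.

However, adding more layers brings many side effects, such as overfitting, vanishing gradients, and exploding gradients.

GoogLeNet won the classification task of the 2014 ImageNet Large Scale Visual Recognition Challenge (ILSVRC), ahead of VGG from the same year.

It uses only about 1/12 as many parameters as AlexNet, the 2012 winner.

It is called GoogLeNet rather than GoogleNet as a tribute to the earlier LeNet.

Paper: Going deeper with convolutions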

[Figure 1]

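GoogLeNet Explained

Inception module

Goal: increase the depth and width of the network while keeping the computational budget constant, and thereby improve results.

The paper points out that simply making a network deeper and wider has two drawbacks:

(1) A bigger network has more parameters, which makes it more prone to overfitting;
(2) The use of computational resources increases dramatically.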
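The paper's reasoning behind drawback (2) is that when two convolutional layers are chained, a uniform increase in their number of filters causes a quadratic increase in computation. The short sketch below illustrates this for the second of two chained 3 x 3 layers; the helper `conv_weights` is our own illustration and ignores biases.

```python
def conv_weights(k: int, c_in: int, c_out: int) -> int:
    """Number of weights in a k x k convolution from c_in to c_out channels (bias ignored)."""
    return k * k * c_in * c_out


# Second of two chained 3 x 3 layers: doubling the width doubles BOTH its
# input and its output channels, so its weight count grows 4x.
narrow = conv_weights(3, 64, 64)    # 36,864 weights
wide = conv_weights(3, 128, 128)    # 147,456 weights
print(wide // narrow)               # 4
```

At a fixed spatial resolution the multiply-adds of that layer scale by the same factor.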

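Core idea

(1) Use 1 x 1 convolutions to increase or reduce the channel dimension;
(2) Convolve at multiple scales in parallel, then aggregate the results.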
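A minimal PyTorch sketch of idea (2), assuming the naive layout (structure (a) below): 1 x 1, 3 x 3 and 5 x 5 convolutions plus a 3 x 3 max pooling run in parallel on the same input, and their outputs are concatenated along the channel dimension. `NaiveInception` is a hypothetical name, ReLU activations are omitted for brevity, and the channel counts are borrowed from inception3a.

```python
import torch
import torch.nn as nn


class NaiveInception(nn.Module):
    """Parallel 1x1 / 3x3 / 5x5 convolutions and 3x3 max pooling, concatenated on channels."""

    def __init__(self, in_channels: int, c1: int, c3: int, c5: int):
        super().__init__()
        self.b1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        self.b3 = nn.Conv2d(in_channels, c3, kernel_size=3, padding=1)  # padding keeps H x W
        self.b5 = nn.Conv2d(in_channels, c5, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branches = [self.b1(x), self.b3(x), self.b5(x), self.pool(x)]
        return torch.cat(branches, dim=1)  # aggregate along the channel axis


block = NaiveInception(192, 64, 128, 32)
y = block(torch.randn(1, 192, 28, 28))
print(y.shape)  # torch.Size([1, 416, 28, 28]): 64 + 128 + 32 + 192 channels
```

Because the pooling branch passes all 192 input channels through unchanged, the output is wider than the input; stacking such modules makes the width, and with it the cost, grow from stage to stage, which is one motivation for the 1 x 1 reductions of structure (b).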

[Figure 2: Inception module, (a) naive version, (b) with 1 x 1 dimension reductions]

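In structure (a), the 5 x 5 convolutions cause a large increase in the number of parameters;
structure (b) was therefore proposed, which applies 1 x 1 convolutions before the 3 x 3 and 5 x 5 convolutions (and after the 3 x 3 max pooling).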
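To quantify the difference, the sketch below counts convolution weights (biases ignored) for structures (a) and (b), assuming the filter counts of inception3a from the reproduction code (input 192 channels; ch1x1=64, ch3x3red=96, ch3x3=128, ch5x5red=16, ch5x5=32, pool_proj=32). `conv_weights` is an illustrative helper.

```python
def conv_weights(k, c_in, c_out):
    """Number of weights in a k x k convolution (bias ignored)."""
    return k * k * c_in * c_out


# Filter counts of inception (3a), taken from Inception(192, 64, 96, 128, 16, 32, 32)
c_in, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj = 192, 64, 96, 128, 16, 32, 32

# Structure (a): every convolution sees the full 192-channel input.
naive = (conv_weights(1, c_in, ch1x1)
         + conv_weights(3, c_in, ch3x3)
         + conv_weights(5, c_in, ch5x5))

# Structure (b): 1 x 1 reductions before 3 x 3 / 5 x 5, 1 x 1 projection after pooling.
reduced = (conv_weights(1, c_in, ch1x1)
           + conv_weights(1, c_in, ch3x3red) + conv_weights(3, ch3x3red, ch3x3)
           + conv_weights(1, c_in, ch5x5red) + conv_weights(5, ch5x5red, ch5x5)
           + conv_weights(1, c_in, pool_proj))

print(naive, reduced)  # 387072 163328
```

With these numbers, structure (b) needs well under half the weights of (a), even though it adds an extra 1 x 1 projection to the pooling branch.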

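1 x 1 convolutions:

Influenced by the Network-in-Network approach of Lin et al., every 1 x 1 convolution is followed by a rectified linear activation (ReLU), which increases the representational power of the network;
their main role is as dimension-reduction modules that remove computational bottlenecks, so that the network's depth and width can grow without a significant loss in performance.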
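The same saving shows up in compute. Below is a rough count of multiply-accumulate operations for the 5 x 5 branch of inception3a on its 28 x 28 input, assuming stride 1 and "same" padding, and ignoring biases and activations; `conv_macs` is a hypothetical helper.

```python
def conv_macs(h, w, k, c_in, c_out):
    """Multiply-accumulates of a stride-1, same-padded k x k convolution on an h x w map."""
    return h * w * k * k * c_in * c_out


H = W = 28  # spatial size at inception (3a)

# 5 x 5 convolution applied directly to the 192-channel input
direct = conv_macs(H, W, 5, 192, 32)

# 1 x 1 reduction to 16 channels, then the 5 x 5 convolution
bottleneck = conv_macs(H, W, 1, 192, 16) + conv_macs(H, W, 5, 16, 32)

print(direct, bottleneck, round(direct / bottleneck, 1))  # 120422400 12443648 9.7
```

The 16-channel bottleneck cuts this branch's cost by roughly 9.7x (about 120.4M vs 12.4M multiply-accumulates).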

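GoogLeNet Architecture

Auxiliary classifiers with softmax outputs are attached to Inception (4a) and (4d) to help propagate gradients back to the earlier layers;
during training, their losses are added to the network's total loss with a discount weight (each auxiliary loss is weighted by 0.3), and at inference time the auxiliary networks are discarded.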
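A minimal sketch of how the three losses could be combined during training, assuming the model returns `(out, aux1, aux2)` logits in training mode (as the `GoogLeNet` class below does) and that each classifier uses softmax cross-entropy. `googlenet_loss` and `AUX_WEIGHT` are hypothetical names.

```python
import torch
import torch.nn.functional as F

AUX_WEIGHT = 0.3  # discount weight applied to each auxiliary classifier's loss


def googlenet_loss(out, aux1, aux2, target):
    """Main cross-entropy loss plus the two auxiliary losses, each weighted by 0.3."""
    main = F.cross_entropy(out, target)
    aux = F.cross_entropy(aux1, target) + F.cross_entropy(aux2, target)
    return main + AUX_WEIGHT * aux


# Dummy logits for a batch of 4 samples and 10 classes
torch.manual_seed(0)
logits = [torch.randn(4, 10) for _ in range(3)]
target = torch.randint(0, 10, (4,))
loss = googlenet_loss(*logits, target)
print(loss.item())
```

At inference time only `out` is used, so the auxiliary heads add no cost once training is done.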

Input size: (224 x 224 x 3)
conv1: convolution (7 x 7), stride 2, output (112 x 112 x 64)
maxpool1: max pooling (3 x 3), stride 2, output (56 x 56 x 64)
conv2: convolution (1 x 1), stride 1, output (56 x 56 x 64)
conv3: convolution (3 x 3), stride 1, output (56 x 56 x 192)
maxpool2: max pooling (3 x 3), stride 2, output (28 x 28 x 192)
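The spatial sizes above follow from the usual output-size formula, out = floor((n + 2p - k) / s) + 1, with ceil instead of floor for the pooling layers built with `ceil_mode=True`. A small sketch that replays the stem, assuming the kernel, stride and padding values of the reproduction code below:

```python
import math


def out_size(n, k, s=1, p=0, ceil_mode=False):
    """Output length of a conv/pool window along one axis.

    Simplified: ignores PyTorch's extra ceil_mode rule about windows that
    would start inside the right-hand padding (it does not trigger here).
    """
    frac = (n + 2 * p - k) / s
    return (math.ceil(frac) if ceil_mode else math.floor(frac)) + 1


# (name, kernel, stride, padding, ceil_mode) for the stem of the reproduction code
stem = [
    ("conv1", 7, 2, 3, False),
    ("maxpool1", 3, 2, 0, True),
    ("conv2", 1, 1, 0, False),
    ("conv3", 3, 1, 1, False),
    ("maxpool2", 3, 2, 0, True),
]

n = 224
sizes = {}
for name, k, s, p, ceil_mode in stem:
    n = out_size(n, k, s, p, ceil_mode)
    sizes[name] = n
print(sizes)  # {'conv1': 112, 'maxpool1': 56, 'conv2': 56, 'conv3': 56, 'maxpool2': 28}
```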

inception3a: output (28 x 28 x 256)
inception3b: output (28 x 28 x 480)
maxpool3: output (14 x 14 x 480)

inception4a: output (14 x 14 x 512)
inception4b: output (14 x 14 x 512)
inception4c: output (14 x 14 x 512)
inception4d: output (14 x 14 x 528)
inception4e: output (14 x 14 x 832)
maxpool4: output (7 x 7 x 832)

inception5a: output (7 x 7 x 832)
inception5b: output (7 x 7 x 1024)

avgpool: output (1 x 1 x 1024)
dropout: 40%
Linear: num_classes
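Each Inception module's output depth is the sum of its four branch outputs (ch1x1 + ch3x3 + ch5x5 + pool_proj); the reduction layers only change the internal width. A quick consistency check, using the `Inception(...)` arguments from the reproduction code below (the dictionary keys are our own labels):

```python
# (in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj), as passed to Inception(...)
configs = {
    "3a": (192, 64, 96, 128, 16, 32, 32),
    "3b": (256, 128, 128, 192, 32, 96, 64),
    "4a": (480, 192, 96, 208, 16, 48, 64),
    "4b": (512, 160, 112, 224, 24, 64, 64),
    "4c": (512, 128, 128, 256, 24, 64, 64),
    "4d": (512, 112, 144, 288, 32, 64, 64),
    "4e": (528, 256, 160, 320, 32, 128, 128),
    "5a": (832, 256, 160, 320, 32, 128, 128),
    "5b": (832, 384, 192, 384, 48, 128, 128),
}


def out_channels(cfg):
    """Only the last layer of each branch contributes to the concatenated output."""
    _, ch1x1, _, ch3x3, _, ch5x5, pool_proj = cfg
    return ch1x1 + ch3x3 + ch5x5 + pool_proj


outs = {name: out_channels(cfg) for name, cfg in configs.items()}
print(outs)
# {'3a': 256, '3b': 480, '4a': 512, '4b': 512, '4c': 512, '4d': 528, '4e': 832, '5a': 832, '5b': 1024}

# Max pooling keeps the channel count, so each output must match the next module's input.
names = list(configs)
for cur, nxt in zip(names, names[1:]):
    assert outs[cur] == configs[nxt][0], (cur, nxt)
```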


[Figure 3]

Reproducing GoogLeNet

# Here is the code :

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):        # **kwargs: variable-length keyword arguments (kernel_size, stride, padding, ...)
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        # Note: the paper uses a 5 x 5 convolution in this branch. kernel_size=3 mirrors
        # torchvision's GoogLeNet and matches the summary below; use kernel_size=5,
        # padding=2 to follow the paper exactly.
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = torch.cat((branch1, branch2, branch3, branch4), 1)
        return outputs

class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.avgPool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.avgPool(x)
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.7, training=self.training)
        out = self.fc2(x)
        return out

class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
            self.aux1 = InceptionAux(512, num_classes)   # attached after inception4a
            self.aux2 = InceptionAux(528, num_classes)   # attached after inception4d
        else:
            self.aux1 = None
            self.aux2 = None

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self.initialize_weights()

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                # Truncated normal init: std 0.01, restricted to [-2, 2] standard deviations
                import scipy.stats as stats
                X = stats.truncnorm(-2, 2, scale=0.01)
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        out = self.fc(x)

        if self.training and self.aux_logits:
            return out, aux1, aux2   # training: main output plus both auxiliary outputs
        return out                   # inference: auxiliary classifiers are skipped

def test():
    net = GoogLeNet()
    out, aux1, aux2 = net(torch.randn(1, 3, 224, 224))   # modules start in training mode
    print(out.shape)
    summary(net, (1, 3, 224, 224))

if __name__ == '__main__':
    test()

torch.Size([1, 1000])
Layer (type:depth-idx)                   Output Shape              Param #
GoogLeNet                                --                        --
├─BasicConv2d: 1-1                       [1, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [1, 64, 112, 112]         128
├─MaxPool2d: 1-2                         [1, 64, 56, 56]           --
├─BasicConv2d: 1-3                       [1, 64, 56, 56]           --
│    └─Conv2d: 2-3                       [1, 64, 56, 56]           4,096
│    └─BatchNorm2d: 2-4                  [1, 64, 56, 56]           128
├─BasicConv2d: 1-4                       [1, 192, 56, 56]          --
│    └─Conv2d: 2-5                       [1, 192, 56, 56]          110,592
│    └─BatchNorm2d: 2-6                  [1, 192, 56, 56]          384
├─MaxPool2d: 1-5                         [1, 192, 28, 28]          --
├─Inception: 1-6                         [1, 256, 28, 28]          --
│    └─BasicConv2d: 2-7                  [1, 64, 28, 28]           --
│    │    └─Conv2d: 3-1                  [1, 64, 28, 28]           12,288
│    │    └─BatchNorm2d: 3-2             [1, 64, 28, 28]           128
│    └─Sequential: 2-8                   [1, 128, 28, 28]          --
│    │    └─BasicConv2d: 3-3             [1, 96, 28, 28]           18,624
│    │    └─BasicConv2d: 3-4             [1, 128, 28, 28]          110,848
│    └─Sequential: 2-9                   [1, 32, 28, 28]           --
│    │    └─BasicConv2d: 3-5             [1, 16, 28, 28]           3,104
│    │    └─BasicConv2d: 3-6             [1, 32, 28, 28]           4,672
│    └─Sequential: 2-10                  [1, 32, 28, 28]           --
│    │    └─MaxPool2d: 3-7               [1, 192, 28, 28]          --
│    │    └─BasicConv2d: 3-8             [1, 32, 28, 28]           6,208
├─Inception: 1-7                         [1, 480, 28, 28]          --
│    └─BasicConv2d: 2-11                 [1, 128, 28, 28]          --
│    │    └─Conv2d: 3-9                  [1, 128, 28, 28]          32,768
│    │    └─BatchNorm2d: 3-10            [1, 128, 28, 28]          256
│    └─Sequential: 2-12                  [1, 192, 28, 28]          --
│    │    └─BasicConv2d: 3-11            [1, 128, 28, 28]          33,024
│    │    └─BasicConv2d: 3-12            [1, 192, 28, 28]          221,568
│    └─Sequential: 2-13                  [1, 96, 28, 28]           --
│    │    └─BasicConv2d: 3-13            [1, 32, 28, 28]           8,256
│    │    └─BasicConv2d: 3-14            [1, 96, 28, 28]           27,840
│    └─Sequential: 2-14                  [1, 64, 28, 28]           --
│    │    └─MaxPool2d: 3-15              [1, 256, 28, 28]          --
│    │    └─BasicConv2d: 3-16            [1, 64, 28, 28]           16,512
├─MaxPool2d: 1-8                         [1, 480, 14, 14]          --
├─Inception: 1-9                         [1, 512, 14, 14]          --
│    └─BasicConv2d: 2-15                 [1, 192, 14, 14]          --
│    │    └─Conv2d: 3-17                 [1, 192, 14, 14]          92,160
│    │    └─BatchNorm2d: 3-18            [1, 192, 14, 14]          384
│    └─Sequential: 2-16                  [1, 208, 14, 14]          --
│    │    └─BasicConv2d: 3-19            [1, 96, 14, 14]           46,272
│    │    └─BasicConv2d: 3-20            [1, 208, 14, 14]          180,128
│    └─Sequential: 2-17                  [1, 48, 14, 14]           --
│    │    └─BasicConv2d: 3-21            [1, 16, 14, 14]           7,712
│    │    └─BasicConv2d: 3-22            [1, 48, 14, 14]           7,008
│    └─Sequential: 2-18                  [1, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-23              [1, 480, 14, 14]          --
│    │    └─BasicConv2d: 3-24            [1, 64, 14, 14]           30,848
├─Inception: 1-10                        [1, 512, 14, 14]          --
│    └─BasicConv2d: 2-19                 [1, 160, 14, 14]          --
│    │    └─Conv2d: 3-25                 [1, 160, 14, 14]          81,920
│    │    └─BatchNorm2d: 3-26            [1, 160, 14, 14]          320
│    └─Sequential: 2-20                  [1, 224, 14, 14]          --
│    │    └─BasicConv2d: 3-27            [1, 112, 14, 14]          57,568
│    │    └─BasicConv2d: 3-28            [1, 224, 14, 14]          226,240
│    └─Sequential: 2-21                  [1, 64, 14, 14]           --
│    │    └─BasicConv2d: 3-29            [1, 24, 14, 14]           12,336
│    │    └─BasicConv2d: 3-30            [1, 64, 14, 14]           13,952
│    └─Sequential: 2-22                  [1, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-31              [1, 512, 14, 14]          --
│    │    └─BasicConv2d: 3-32            [1, 64, 14, 14]           32,896
├─Inception: 1-11                        [1, 512, 14, 14]          --
│    └─BasicConv2d: 2-23                 [1, 128, 14, 14]          --
│    │    └─Conv2d: 3-33                 [1, 128, 14, 14]          65,536
│    │    └─BatchNorm2d: 3-34            [1, 128, 14, 14]          256
│    └─Sequential: 2-24                  [1, 256, 14, 14]          --
│    │    └─BasicConv2d: 3-35            [1, 128, 14, 14]          65,792
│    │    └─BasicConv2d: 3-36            [1, 256, 14, 14]          295,424
│    └─Sequential: 2-25                  [1, 64, 14, 14]           --
│    │    └─BasicConv2d: 3-37            [1, 24, 14, 14]           12,336
│    │    └─BasicConv2d: 3-38            [1, 64, 14, 14]           13,952
│    └─Sequential: 2-26                  [1, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-39              [1, 512, 14, 14]          --
│    │    └─BasicConv2d: 3-40            [1, 64, 14, 14]           32,896
├─Inception: 1-12                        [1, 528, 14, 14]          --
│    └─BasicConv2d: 2-27                 [1, 112, 14, 14]          --
│    │    └─Conv2d: 3-41                 [1, 112, 14, 14]          57,344
│    │    └─BatchNorm2d: 3-42            [1, 112, 14, 14]          224
│    └─Sequential: 2-28                  [1, 288, 14, 14]          --
│    │    └─BasicConv2d: 3-43            [1, 144, 14, 14]          74,016
│    │    └─BasicConv2d: 3-44            [1, 288, 14, 14]          373,824
│    └─Sequential: 2-29                  [1, 64, 14, 14]           --
│    │    └─BasicConv2d: 3-45            [1, 32, 14, 14]           16,448
│    │    └─BasicConv2d: 3-46            [1, 64, 14, 14]           18,560
│    └─Sequential: 2-30                  [1, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-47              [1, 512, 14, 14]          --
│    │    └─BasicConv2d: 3-48            [1, 64, 14, 14]           32,896
├─Inception: 1-13                        [1, 832, 14, 14]          --
│    └─BasicConv2d: 2-31                 [1, 256, 14, 14]          --
│    │    └─Conv2d: 3-49                 [1, 256, 14, 14]          135,168
│    │    └─BatchNorm2d: 3-50            [1, 256, 14, 14]          512
│    └─Sequential: 2-32                  [1, 320, 14, 14]          --
│    │    └─BasicConv2d: 3-51            [1, 160, 14, 14]          84,800
│    │    └─BasicConv2d: 3-52            [1, 320, 14, 14]          461,440
│    └─Sequential: 2-33                  [1, 128, 14, 14]          --
│    │    └─BasicConv2d: 3-53            [1, 32, 14, 14]           16,960
│    │    └─BasicConv2d: 3-54            [1, 128, 14, 14]          37,120
│    └─Sequential: 2-34                  [1, 128, 14, 14]          --
│    │    └─MaxPool2d: 3-55              [1, 528, 14, 14]          --
│    │    └─BasicConv2d: 3-56            [1, 128, 14, 14]          67,840
├─MaxPool2d: 1-14                        [1, 832, 7, 7]            --
├─Inception: 1-15                        [1, 832, 7, 7]            --
│    └─BasicConv2d: 2-35                 [1, 256, 7, 7]            --
│    │    └─Conv2d: 3-57                 [1, 256, 7, 7]            212,992
│    │    └─BatchNorm2d: 3-58            [1, 256, 7, 7]            512
│    └─Sequential: 2-36                  [1, 320, 7, 7]            --
│    │    └─BasicConv2d: 3-59            [1, 160, 7, 7]            133,440
│    │    └─BasicConv2d: 3-60            [1, 320, 7, 7]            461,440
│    └─Sequential: 2-37                  [1, 128, 7, 7]            --
│    │    └─BasicConv2d: 3-61            [1, 32, 7, 7]             26,688
│    │    └─BasicConv2d: 3-62            [1, 128, 7, 7]            37,120
│    └─Sequential: 2-38                  [1, 128, 7, 7]            --
│    │    └─MaxPool2d: 3-63              [1, 832, 7, 7]            --
│    │    └─BasicConv2d: 3-64            [1, 128, 7, 7]            106,752
├─Inception: 1-16                        [1, 1024, 7, 7]           --
│    └─BasicConv2d: 2-39                 [1, 384, 7, 7]            --
│    │    └─Conv2d: 3-65                 [1, 384, 7, 7]            319,488
│    │    └─BatchNorm2d: 3-66            [1, 384, 7, 7]            768
│    └─Sequential: 2-40                  [1, 384, 7, 7]            --
│    │    └─BasicConv2d: 3-67            [1, 192, 7, 7]            160,128
│    │    └─BasicConv2d: 3-68            [1, 384, 7, 7]            664,320
│    └─Sequential: 2-41                  [1, 128, 7, 7]            --
│    │    └─BasicConv2d: 3-69            [1, 48, 7, 7]             40,032
│    │    └─BasicConv2d: 3-70            [1, 128, 7, 7]            55,552
│    └─Sequential: 2-42                  [1, 128, 7, 7]            --
│    │    └─MaxPool2d: 3-71              [1, 832, 7, 7]            --
│    │    └─BasicConv2d: 3-72            [1, 128, 7, 7]            106,752
├─AdaptiveAvgPool2d: 1-17                [1, 1024, 1, 1]           --
├─Dropout: 1-18                          [1, 1024]                 --
├─Linear: 1-19                           [1, 1000]                 1,025,000
Total params: 6,624,904
Trainable params: 6,624,904
Non-trainable params: 0
Total mult-adds (G): 1.50
Input size (MB): 0.60
Forward/backward pass size (MB): 51.63
Params size (MB): 26.50
Estimated Total Size (MB): 78.73
