Inception v3

Inception v3

  1. 2014年ImageNet竞赛的冠军Inception-v1,又名GoogLeNet。Inception v1的特点:通过Inception模块增加网络的宽度,即将模块的输入并行地经过几种不同尺寸的卷积(及池化)计算,再将各分支的结果以concat方式在通道维度上连接。
  2. Inception v2,在v1版本上改进2个方向:
    1. 引入BN层,模型在计算过程中会先对每层的输入进行归一化
    2. 用两个3x3卷积替代5x5卷积,在感受野不变的情况下减少参数量
  3. Inception v3,在之前的基础上增加:
    1. 将大卷积分解成小卷积,使得在感受野不变的情况下,减少参数的计算量
    2. 直接用max pooling层进行下采样会导致较大的信息损失,而先卷积再池化的计算量又很大。于是设计成并行计算输入A的步长为2的卷积结果和输入A的pooling结果,并将卷积的结果与池化的结果concat。这样既减少计算量又减少信息损失。

本次复现的是Inception v3(pytorch)。原论文链接https://arxiv.org/abs/1512.00567

整个模型都基于基本卷积结构:卷积->BN->ReLU

# 卷积->BN->Relu 代码如下
class Conv2d(nn.Module):
    """Basic building unit of the network: convolution -> BatchNorm -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution (also the BN width).
        kernel_size: convolution kernel size, an int or an (h, w) tuple.
        stride: convolution stride.
        padding: zero padding applied by the convolution.
        bias: whether the convolution has its own bias term (off by default,
            since the following BatchNorm2d carries a learnable shift).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()
        # Registration order (conv, bn, relu) fixes the state_dict layout.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

InceptionA 结构

Inception v3_第1张图片

代码实现:

class InceptionA(nn.Module):
    """Inception block used on the 35x35 feature maps of Inception_v3.

    Four parallel branches whose outputs are concatenated on the channel axis:

    * 1x1 conv -> 64 channels
    * 1x1 conv (48) -> 5x5 conv -> 64 channels
    * 1x1 conv (64) -> 3x3 conv (96) -> 3x3 conv -> 96 channels
    * 3x3 average pool (stride 1) -> 1x1 conv -> ``pool_channels`` channels

    Every branch keeps the spatial size, so the output is
    N x (64 + 64 + 96 + pool_channels) x H x W.

    Args:
        in_channels: channels of the input feature map.
        pool_channels: output channels of the pooling branch.
    """

    def __init__(self, in_channels, pool_channels):
        super().__init__()
        # Attribute names and their order define the state_dict layout.
        self.branch1x1 = Conv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = Conv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = Conv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3_1 = Conv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = Conv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = Conv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = Conv2d(in_channels, pool_channels, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        # dim 1 is the channel axis of an N x C x H x W tensor (dim 0 is batch).
        return torch.cat((out_1x1, out_5x5, out_3x3, out_pool), dim=1)

InceptionB结构

Inception v3_第2张图片

代码实现:

class InceptionB(nn.Module):
    """Spatial-downsampling Inception block (35x35 -> 17x17 inside Inception_v3).

    Three parallel branches concatenated on the channel axis:

    * 3x3 conv, stride 2 -> 384 channels
    * 1x1 conv (64) -> 3x3 conv (96, padding 1) -> 3x3 conv, stride 2 -> 96 channels
    * 3x3 max pool, stride 2 -> ``in_channels`` channels (passed through)

    Output channels: 384 + 96 + in_channels (768 when in_channels is 288).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=3, stride=2)

        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=64, kernel_size=1)
        self.branch3x3_2 = Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1)
        # Consumes the 96-channel output of branch3x3_2.
        self.branch3x3_3 = Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2)

    def forward(self, x):
        x_3x3 = self.branch3x3(x)

        x_3x3_1 = self.branch3x3_1(x)
        x_3x3_1 = self.branch3x3_2(x_3x3_1)
        x_3x3_1 = self.branch3x3_3(x_3x3_1)

        x_branch = F.max_pool2d(x, kernel_size=3, stride=2)
        # dim 1 is the channel axis: 384 + 96 + in_channels.
        return torch.cat([x_3x3, x_3x3_1, x_branch], 1)

InceptionC

Inception v3_第3张图片

代码实现:

class InceptionC(nn.Module):
    """Inception block with factorised 7x7 convolutions, used on 17x17 maps.

    Four parallel branches concatenated on the channel axis:

    * 1x1 conv -> 192 channels
    * 1x1 conv -> 1x7 conv -> 7x1 conv -> 192 channels
    * 1x1 conv -> 7x1 -> 1x7 -> 7x1 -> 1x7 convs -> 192 channels
    * 3x3 average pool (stride 1) -> 1x1 conv -> 192 channels

    Args:
        in_channels: channels of the input feature map.
        channels_7x7: width of the intermediate layers in the 7x7 branches.

    Spatial size is preserved; the output always has 4 * 192 = 768 channels.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__()
        c7 = channels_7x7
        self.branch_1x1 = Conv2d(in_channels, 192, kernel_size=1)

        # One 1x7 + 7x1 pair.
        self.branch_7x7_1 = Conv2d(in_channels, c7, kernel_size=1)
        self.branch_7x7_2 = Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_3 = Conv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        # Two stacked 7x1 + 1x7 pairs.
        self.branch_7x7_4 = Conv2d(in_channels, c7, kernel_size=1)
        self.branch_7x7_5 = Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_6 = Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_7 = Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_8 = Conv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = Conv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch_1x1(x)

        out_7x7 = x
        for layer in (self.branch_7x7_1, self.branch_7x7_2, self.branch_7x7_3):
            out_7x7 = layer(out_7x7)

        out_7x7_dbl = x
        for layer in (self.branch_7x7_4, self.branch_7x7_5, self.branch_7x7_6,
                      self.branch_7x7_7, self.branch_7x7_8):
            out_7x7_dbl = layer(out_7x7_dbl)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))

        # Channel-wise concat (dim 1): 192 * 4 = 768 channels.
        return torch.cat([out_1x1, out_7x7, out_7x7_dbl, out_pool], dim=1)

InceptionD结构

Inception v3_第4张图片

代码实现:

class InceptionD(nn.Module):
    """Spatial-downsampling Inception block (17x17 -> 8x8 inside Inception_v3).

    Three parallel branches concatenated on the channel axis:

    * 1x1 conv (192) -> 3x3 conv, stride 2 -> 320 channels
    * 1x1 conv (192) -> 1x7 conv -> 7x1 conv -> 3x3 conv, stride 2 -> 192 channels
    * 3x3 max pool, stride 2 -> ``in_channels`` channels (passed through)

    Output channels: 320 + 192 + in_channels (1280 when in_channels is 768).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch3x3_1 = Conv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = Conv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = Conv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = Conv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = Conv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = Conv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = x
        for layer in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            out_7x7x3 = layer(out_7x7x3)

        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        # Channel-wise concat (dim 1).
        return torch.cat([out_3x3, out_7x7x3, out_pool], dim=1)

InceptionE结构

Inception v3_第5张图片

代码实现:

class InceptionE(nn.Module):
    """Inception block with expanded filter-bank outputs, used on 8x8 maps.

    Four parallel branches concatenated on the channel axis:

    * 1x1 conv -> 320 channels
    * 1x1 conv (384) -> {1x3 conv, 3x1 conv} side by side -> 384 + 384 channels
    * 1x1 conv (448) -> 3x3 conv (384) -> {1x3 conv, 3x1 conv} -> 384 + 384 channels
    * 3x3 average pool (stride 1) -> 1x1 conv -> 192 channels

    Spatial size is preserved; output channels are 320 + 768 + 768 + 192 = 2048.

    Args:
        in_channels: channels of the input feature map (1280 or 2048 in Inception_v3).
    """

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = Conv2d(in_channels=in_channels, out_channels=320, kernel_size=1)

        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=1)
        self.branch3x3_2a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = Conv2d(in_channels=in_channels, out_channels=448, kernel_size=1)
        self.branch3x3dbl_2 = Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)

    def forward(self, x):
        x_1x1 = self.branch1x1(x)

        # The 1x3 and 3x1 convs read the same tensor; their outputs are
        # concatenated side by side rather than applied in sequence.
        x_3x3 = self.branch3x3_1(x)
        x_3x3 = torch.cat([self.branch3x3_2a(x_3x3), self.branch3x3_2b(x_3x3)], 1)

        x_3x3_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        x_3x3_dbl = torch.cat([self.branch3x3dbl_3a(x_3x3_dbl), self.branch3x3dbl_3b(x_3x3_dbl)], 1)

        x_branch = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        x_branch = self.branch_pool(x_branch)

        # dim 1 is the channel axis: 320 + 384*2 + 384*2 + 192 = 2048.
        return torch.cat([x_1x1, x_3x3, x_3x3_dbl, x_branch], 1)

Aux结构:

Inception v3_第6张图片

代码:

class Inception_Aux(nn.Module):
    """Auxiliary classifier head attached to the 17x17 features of Inception_v3.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 conv (128) -> 5x5 conv (768)
    -> global average pool -> linear layer with ``num_classes`` outputs.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(Inception_Aux, self).__init__()
        self.conv0 = Conv2d(in_channels=in_channels, out_channels=128, kernel_size=1)
        self.conv1 = Conv2d(in_channels=128, out_channels=768, kernel_size=5)
        # NOTE(review): these stddev hints are not read by any code in this file;
        # presumably meant for a torchvision-style truncated-normal init.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x 768 x 17 x 17 (for a 299x299 network input)
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Global *average* pooling, consistent with the main classifier head;
        # a no-op at 1x1 but well-defined when larger inputs leave a bigger map.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        return x

主体结构:

Inception v3 整体模型结构

Inception v3_第7张图片

主体代码:

class Inception_v3(nn.Module):
    """Inception v3 image classifier.

    Args:
        num_classes: number of logits produced by the main (and auxiliary) head.
        aux_logits: if truthy, an auxiliary classifier is built on the 17x17
            features; it is only evaluated while the module is in training mode.

    ``forward`` returns ``(logits, aux)``; ``aux`` is ``None`` unless the model
    is training and ``aux_logits`` is enabled. The shape comments below assume
    an N x 3 x 299 x 299 input.
    """

    def __init__(self, num_classes, aux_logits):
        super().__init__()
        # Stem: 3 conv -> max pool -> 2 conv -> max pool.
        self.conv_1 = Conv2d(3, 32, kernel_size=3, stride=2)
        self.conv_2 = Conv2d(32, 32, kernel_size=3)
        self.conv_3 = Conv2d(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv_4 = Conv2d(64, 80, kernel_size=1)
        self.conv_5 = Conv2d(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # 35x35 stage; channels 192 -> 256 -> 288 -> 288.
        self.inception_a_1 = InceptionA(192, pool_channels=32)
        self.inception_a_2 = InceptionA(256, pool_channels=64)
        self.inception_a_3 = InceptionA(288, pool_channels=64)
        # 35x35 -> 17x17; channels 384 + 96 + 288 = 768.
        self.inception_b_1 = InceptionB(288)
        # 17x17 stage; every block outputs 4 * 192 = 768 channels.
        self.inception_c_1 = InceptionC(768, channels_7x7=128)
        self.inception_c_2 = InceptionC(768, channels_7x7=160)
        self.inception_c_3 = InceptionC(768, channels_7x7=160)
        self.inception_c_4 = InceptionC(768, channels_7x7=192)

        self.aux_logits = aux_logits
        if aux_logits:
            self.auxlogits = Inception_Aux(768, num_classes=num_classes)

        # 17x17 -> 8x8 with 1280 channels, then two 8x8 blocks with 2048 channels.
        self.inception_d = InceptionD(768)
        self.inception_e_1 = InceptionE(1280)
        self.inception_e_2 = InceptionE(2048)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Stem: N x 3 x 299 x 299 -> N x 192 x 35 x 35
        for stage in (self.conv_1, self.conv_2, self.conv_3, self.maxpool1,
                      self.conv_4, self.conv_5, self.maxpool2):
            x = stage(x)

        # N x 192 x 35 x 35 -> N x 288 x 35 x 35 -> N x 768 x 17 x 17
        for stage in (self.inception_a_1, self.inception_a_2, self.inception_a_3,
                      self.inception_b_1,
                      self.inception_c_1, self.inception_c_2,
                      self.inception_c_3, self.inception_c_4):
            x = stage(x)

        # Auxiliary head runs on the 17x17 features, training mode only.
        aux = self.auxlogits(x) if (self.training and self.aux_logits) else None

        # N x 768 x 17 x 17 -> N x 1280 x 8 x 8 -> N x 2048 x 8 x 8
        for stage in (self.inception_d, self.inception_e_1, self.inception_e_2):
            x = stage(x)

        # Global average pool + dropout: N x 2048 x 1 x 1
        x = self.dropout(self.avgpool(x))
        # N x 2048 -> N x num_classes
        logits = self.fc(torch.flatten(x, 1))
        return logits, aux

全部代码:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Module):
    """Basic building unit of the network: convolution -> BatchNorm -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution (also the BN width).
        kernel_size: convolution kernel size, an int or an (h, w) tuple.
        stride: convolution stride.
        padding: zero padding applied by the convolution.
        bias: whether the convolution has its own bias term (off by default,
            since the following BatchNorm2d carries a learnable shift).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()
        # Registration order (conv, bn, relu) fixes the state_dict layout.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class InceptionA(nn.Module):
    """Inception block used on the 35x35 feature maps of Inception_v3.

    Four parallel branches whose outputs are concatenated on the channel axis:

    * 1x1 conv -> 64 channels
    * 1x1 conv (48) -> 5x5 conv -> 64 channels
    * 1x1 conv (64) -> 3x3 conv (96) -> 3x3 conv -> 96 channels
    * 3x3 average pool (stride 1) -> 1x1 conv -> ``pool_channels`` channels

    Every branch keeps the spatial size, so the output is
    N x (64 + 64 + 96 + pool_channels) x H x W.

    Args:
        in_channels: channels of the input feature map.
        pool_channels: output channels of the pooling branch.
    """

    def __init__(self, in_channels, pool_channels):
        super().__init__()
        # Attribute names and their order define the state_dict layout.
        self.branch1x1 = Conv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = Conv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = Conv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3_1 = Conv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = Conv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = Conv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = Conv2d(in_channels, pool_channels, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        # dim 1 is the channel axis of an N x C x H x W tensor (dim 0 is batch).
        return torch.cat((out_1x1, out_5x5, out_3x3, out_pool), dim=1)


class InceptionB(nn.Module):
    """Spatial-downsampling Inception block (35x35 -> 17x17 inside Inception_v3).

    Three parallel branches concatenated on the channel axis:

    * 3x3 conv, stride 2 -> 384 channels
    * 1x1 conv (64) -> 3x3 conv (96, padding 1) -> 3x3 conv, stride 2 -> 96 channels
    * 3x3 max pool, stride 2 -> ``in_channels`` channels (passed through)

    Output channels: 384 + 96 + in_channels (768 when in_channels is 288).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=3, stride=2)

        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=64, kernel_size=1)
        self.branch3x3_2 = Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1)
        # Consumes the 96-channel output of branch3x3_2.
        self.branch3x3_3 = Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2)

    def forward(self, x):
        x_3x3 = self.branch3x3(x)

        x_3x3_1 = self.branch3x3_1(x)
        x_3x3_1 = self.branch3x3_2(x_3x3_1)
        x_3x3_1 = self.branch3x3_3(x_3x3_1)

        x_branch = F.max_pool2d(x, kernel_size=3, stride=2)
        # dim 1 is the channel axis: 384 + 96 + in_channels.
        return torch.cat([x_3x3, x_3x3_1, x_branch], 1)


class InceptionC(nn.Module):
    """Inception block with factorised 7x7 convolutions, used on 17x17 maps.

    Four parallel branches concatenated on the channel axis:

    * 1x1 conv -> 192 channels
    * 1x1 conv -> 1x7 conv -> 7x1 conv -> 192 channels
    * 1x1 conv -> 7x1 -> 1x7 -> 7x1 -> 1x7 convs -> 192 channels
    * 3x3 average pool (stride 1) -> 1x1 conv -> 192 channels

    Args:
        in_channels: channels of the input feature map.
        channels_7x7: width of the intermediate layers in the 7x7 branches.

    Spatial size is preserved; the output always has 4 * 192 = 768 channels.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__()
        c7 = channels_7x7
        self.branch_1x1 = Conv2d(in_channels, 192, kernel_size=1)

        # One 1x7 + 7x1 pair.
        self.branch_7x7_1 = Conv2d(in_channels, c7, kernel_size=1)
        self.branch_7x7_2 = Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_3 = Conv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        # Two stacked 7x1 + 1x7 pairs.
        self.branch_7x7_4 = Conv2d(in_channels, c7, kernel_size=1)
        self.branch_7x7_5 = Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_6 = Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_7 = Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_8 = Conv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = Conv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch_1x1(x)

        out_7x7 = x
        for layer in (self.branch_7x7_1, self.branch_7x7_2, self.branch_7x7_3):
            out_7x7 = layer(out_7x7)

        out_7x7_dbl = x
        for layer in (self.branch_7x7_4, self.branch_7x7_5, self.branch_7x7_6,
                      self.branch_7x7_7, self.branch_7x7_8):
            out_7x7_dbl = layer(out_7x7_dbl)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))

        # Channel-wise concat (dim 1): 192 * 4 = 768 channels.
        return torch.cat([out_1x1, out_7x7, out_7x7_dbl, out_pool], dim=1)

class Inception_Aux(nn.Module):
    """Auxiliary classifier head attached to the 17x17 features of Inception_v3.

    Pipeline: 5x5 average pool (stride 3) -> 1x1 conv (128) -> 5x5 conv (768)
    -> global average pool -> linear layer with ``num_classes`` outputs.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(Inception_Aux, self).__init__()
        self.conv0 = Conv2d(in_channels=in_channels, out_channels=128, kernel_size=1)
        self.conv1 = Conv2d(in_channels=128, out_channels=768, kernel_size=5)
        # NOTE(review): these stddev hints are not read by any code in this file;
        # presumably meant for a torchvision-style truncated-normal init.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x 768 x 17 x 17 (for a 299x299 network input)
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Global *average* pooling, consistent with the main classifier head;
        # a no-op at 1x1 but well-defined when larger inputs leave a bigger map.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        return x

class InceptionD(nn.Module):
    """Spatial-downsampling Inception block (17x17 -> 8x8 inside Inception_v3).

    Three parallel branches concatenated on the channel axis:

    * 1x1 conv (192) -> 3x3 conv, stride 2 -> 320 channels
    * 1x1 conv (192) -> 1x7 conv -> 7x1 conv -> 3x3 conv, stride 2 -> 192 channels
    * 3x3 max pool, stride 2 -> ``in_channels`` channels (passed through)

    Output channels: 320 + 192 + in_channels (1280 when in_channels is 768).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch3x3_1 = Conv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = Conv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = Conv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = Conv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = Conv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = Conv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = x
        for layer in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            out_7x7x3 = layer(out_7x7x3)

        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        # Channel-wise concat (dim 1).
        return torch.cat([out_3x3, out_7x7x3, out_pool], dim=1)

class InceptionE(nn.Module):
    """Inception block with expanded filter-bank outputs, used on 8x8 maps.

    Four parallel branches concatenated on the channel axis:

    * 1x1 conv -> 320 channels
    * 1x1 conv (384) -> {1x3 conv, 3x1 conv} side by side -> 384 + 384 channels
    * 1x1 conv (448) -> 3x3 conv (384) -> {1x3 conv, 3x1 conv} -> 384 + 384 channels
    * 3x3 average pool (stride 1) -> 1x1 conv -> 192 channels

    Spatial size is preserved; output channels are 320 + 768 + 768 + 192 = 2048.

    Args:
        in_channels: channels of the input feature map (1280 or 2048 in Inception_v3).
    """

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = Conv2d(in_channels=in_channels, out_channels=320, kernel_size=1)

        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=1)
        self.branch3x3_2a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = Conv2d(in_channels=in_channels, out_channels=448, kernel_size=1)
        self.branch3x3dbl_2 = Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)

    def forward(self, x):
        x_1x1 = self.branch1x1(x)

        # The 1x3 and 3x1 convs read the same tensor; their outputs are
        # concatenated side by side rather than applied in sequence.
        x_3x3 = self.branch3x3_1(x)
        x_3x3 = torch.cat([self.branch3x3_2a(x_3x3), self.branch3x3_2b(x_3x3)], 1)

        x_3x3_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        x_3x3_dbl = torch.cat([self.branch3x3dbl_3a(x_3x3_dbl), self.branch3x3dbl_3b(x_3x3_dbl)], 1)

        x_branch = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        x_branch = self.branch_pool(x_branch)

        # Concatenate along dim 1 (channels), not the default dim 0 (batch):
        # 320 + 384*2 + 384*2 + 192 = 2048.
        return torch.cat([x_1x1, x_3x3, x_3x3_dbl, x_branch], 1)

class Inception_v3(nn.Module):
    """Inception v3 image classifier.

    Args:
        num_classes: number of logits produced by the main (and auxiliary) head.
        aux_logits: if truthy, an auxiliary classifier is built on the 17x17
            features; it is only evaluated while the module is in training mode.

    ``forward`` returns ``(logits, aux)``; ``aux`` is ``None`` unless the model
    is training and ``aux_logits`` is enabled. The shape comments below assume
    an N x 3 x 299 x 299 input.
    """

    def __init__(self, num_classes, aux_logits):
        super().__init__()
        # Stem: 3 conv -> max pool -> 2 conv -> max pool.
        self.conv_1 = Conv2d(3, 32, kernel_size=3, stride=2)
        self.conv_2 = Conv2d(32, 32, kernel_size=3)
        self.conv_3 = Conv2d(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv_4 = Conv2d(64, 80, kernel_size=1)
        self.conv_5 = Conv2d(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # 35x35 stage; channels 192 -> 256 -> 288 -> 288.
        self.inception_a_1 = InceptionA(192, pool_channels=32)
        self.inception_a_2 = InceptionA(256, pool_channels=64)
        self.inception_a_3 = InceptionA(288, pool_channels=64)
        # 35x35 -> 17x17; channels 384 + 96 + 288 = 768.
        self.inception_b_1 = InceptionB(288)
        # 17x17 stage; every block outputs 4 * 192 = 768 channels.
        self.inception_c_1 = InceptionC(768, channels_7x7=128)
        self.inception_c_2 = InceptionC(768, channels_7x7=160)
        self.inception_c_3 = InceptionC(768, channels_7x7=160)
        self.inception_c_4 = InceptionC(768, channels_7x7=192)

        self.aux_logits = aux_logits
        if aux_logits:
            self.auxlogits = Inception_Aux(768, num_classes=num_classes)

        # 17x17 -> 8x8 with 1280 channels, then two 8x8 blocks with 2048 channels.
        self.inception_d = InceptionD(768)
        self.inception_e_1 = InceptionE(1280)
        self.inception_e_2 = InceptionE(2048)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Stem: N x 3 x 299 x 299 -> N x 192 x 35 x 35
        for stage in (self.conv_1, self.conv_2, self.conv_3, self.maxpool1,
                      self.conv_4, self.conv_5, self.maxpool2):
            x = stage(x)

        # N x 192 x 35 x 35 -> N x 288 x 35 x 35 -> N x 768 x 17 x 17
        for stage in (self.inception_a_1, self.inception_a_2, self.inception_a_3,
                      self.inception_b_1,
                      self.inception_c_1, self.inception_c_2,
                      self.inception_c_3, self.inception_c_4):
            x = stage(x)

        # Auxiliary head runs on the 17x17 features, training mode only.
        aux = self.auxlogits(x) if (self.training and self.aux_logits) else None

        # N x 768 x 17 x 17 -> N x 1280 x 8 x 8 -> N x 2048 x 8 x 8
        for stage in (self.inception_d, self.inception_e_1, self.inception_e_2):
            x = stage(x)

        # Global average pool + dropout: N x 2048 x 1 x 1
        x = self.dropout(self.avgpool(x))
        # N x 2048 -> N x num_classes
        logits = self.fc(torch.flatten(x, 1))
        return logits, aux

你可能感兴趣的:(#,图像分类,深度学习,pytorch,计算机视觉)