Inception v3
- 2014年ImageNet竞赛(ILSVRC 2014)的冠军Inception-v1,又名GoogLeNet。Inception v1的特点:通过Inception模块增加网络的宽度——将模块的输入分别经过几种不同尺寸的卷积(以及池化)计算,再以concat方式在通道维度上拼接。
- Inception v2,在v1版本上改进2个方向:
- 引入BN层
- 模型在每一层的计算过程中,会先对该层的输入做批归一化(Batch Normalization),使训练更稳定、收敛更快
- Inception v3,在之前的基础上增加:
- 将大卷积分解成小卷积,使得在感受野不变的情况下,减少参数量和计算量。例如用两个3x3卷积代替一个5x5卷积,或用1x7与7x1两个卷积代替一个7x7卷积
- max pooling层在下采样时会造成较大的信息损失,于是改为并行计算:对输入A分别做步长为2的卷积和步长为2的池化,再将卷积结果与池化结果concat。这样既减少了计算量,又减少了信息损失。
本次复现的是Inception v3(pytorch)。原论文链接:https://arxiv.org/abs/1512.00567
整个模型都基于同一个基本卷积单元:卷积 -> BN -> ReLU
# Basic building block used everywhere below: convolution -> BatchNorm -> ReLU
class Conv2d(nn.Module):
    """Convolution followed by batch normalization and an in-place ReLU.

    Wraps ``torch.nn.Conv2d`` (not to be confused with it).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm width).
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
        bias: whether the convolution carries its own bias; off by default
            because the following BatchNorm already has a learnable shift.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super().__init__()
        # The attribute names conv / bn / relu define the state_dict keys.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
InceptionA 结构
代码实现:
class InceptionA(nn.Module):
    """Inception block used at the 35x35 stage of Inception_v3.

    Four parallel branches, each preserving spatial size, concatenated on
    the channel axis. Output channels: 64 + 64 + 96 + pool_channels.

    Args:
        in_channels: channels of the input feature map.
        pool_channels: output channels of the 1x1 conv after the average pool.
    """

    def __init__(self, in_channels, pool_channels):
        super().__init__()
        # Branch 1: plain 1x1 conv.
        self.branch1x1 = Conv2d(in_channels, 64, kernel_size=1)
        # Branch 2: 1x1 reduction -> 5x5 conv.
        self.branch5x5_1 = Conv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = Conv2d(48, 64, kernel_size=5, padding=2)
        # Branch 3: 1x1 reduction -> two stacked 3x3 convs.
        self.branch3x3_1 = Conv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = Conv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = Conv2d(96, 96, kernel_size=3, padding=1)
        # Branch 4: 3x3 average pool -> 1x1 projection.
        self.branch_pool = Conv2d(in_channels, pool_channels, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        # Concatenate along dim 1, the channel axis of an (N, C, H, W) tensor.
        return torch.cat([out_1x1, out_5x5, out_3x3, out_pool], 1)
InceptionB结构
代码实现:
class InceptionB(nn.Module):
    """Downsampling block (35x35 -> 17x17 in Inception_v3).

    Three parallel stride-2 branches concatenated on the channel axis:
    a strided 3x3 conv (384 channels), a 1x1 -> 3x3 -> strided 3x3 stack
    (96 channels), and a 3x3 max pool that passes the input channels through.
    Output channels: 384 + 96 + in_channels.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=3, stride=2)
        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=64, kernel_size=1)
        self.branch3x3_2 = Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1)
        # Must accept branch3x3_2's 96 output channels (63 here caused a
        # channel mismatch error on the first forward pass).
        self.branch3x3_3 = Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2)

    def forward(self, x):
        x_3x3 = self.branch3x3(x)
        x_3x3_1 = self.branch3x3_1(x)
        x_3x3_1 = self.branch3x3_2(x_3x3_1)
        x_3x3_1 = self.branch3x3_3(x_3x3_1)
        x_branch = F.max_pool2d(x, kernel_size=3, stride=2)
        # 384 + 96 + in_channels, concatenated on the channel axis (dim 1).
        return torch.cat([x_3x3, x_3x3_1, x_branch], 1)
InceptionC结构
代码实现:
class InceptionC(nn.Module):
    """Inception block used at the 17x17 stage, with 7x7 convs factorized into 1x7 / 7x1.

    Four parallel branches, spatial size preserved, each producing 192
    channels. Output channels: 192 * 4 = 768.

    Args:
        in_channels: channels of the input feature map.
        channels_7x7: width of the intermediate layers in the two 7x7 branches.
    """

    def __init__(self, in_channels, channels_7x7):
        super().__init__()
        c7 = channels_7x7
        # Branch 1: plain 1x1 conv.
        self.branch_1x1 = Conv2d(in_channels, 192, kernel_size=1)
        # Branch 2: 1x1 -> 1x7 -> 7x1 (one factorized 7x7).
        self.branch_7x7_1 = Conv2d(in_channels, c7, kernel_size=1)
        self.branch_7x7_2 = Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_3 = Conv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
        # Branch 3: 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7 (two factorized 7x7s).
        self.branch_7x7_4 = Conv2d(in_channels, c7, kernel_size=1)
        self.branch_7x7_5 = Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_6 = Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_7 = Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_8 = Conv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
        # Branch 4: 3x3 average pool -> 1x1 projection.
        self.branch_pool = Conv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        plain = self.branch_1x1(x)

        single = x
        for layer in (self.branch_7x7_1, self.branch_7x7_2, self.branch_7x7_3):
            single = layer(single)

        double = x
        for layer in (
            self.branch_7x7_4,
            self.branch_7x7_5,
            self.branch_7x7_6,
            self.branch_7x7_7,
            self.branch_7x7_8,
        ):
            double = layer(double)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        # 192 + 192 + 192 + 192 channels along dim 1.
        return torch.cat([plain, single, double, pooled], 1)
InceptionD结构
代码实现:
class InceptionD(nn.Module):
    """Downsampling block (17x17 -> 8x8 in Inception_v3).

    Three parallel stride-2 branches concatenated on the channel axis:
    1x1 -> strided 3x3 (320), 1x1 -> 1x7 -> 7x1 -> strided 3x3 (192), and a
    3x3 max pool passing the input channels through.
    Output channels: 320 + 192 + in_channels (1280 for in_channels=768).

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.branch3x3_1 = Conv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = Conv2d(192, 320, kernel_size=3, stride=2)
        self.branch7x7x3_1 = Conv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = Conv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = Conv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = Conv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        conv3 = self.branch3x3_2(self.branch3x3_1(x))

        conv7 = x
        for layer in (
            self.branch7x7x3_1,
            self.branch7x7x3_2,
            self.branch7x7x3_3,
            self.branch7x7x3_4,
        ):
            conv7 = layer(conv7)

        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([conv3, conv7, pooled], 1)
InceptionE结构
代码实现:
class InceptionE(nn.Module):
    """Inception block used at the 8x8 stage (after InceptionD) of Inception_v3.

    Four parallel branches, spatial size preserved, concatenated on channels:
    a 1x1 conv (320); a 1x1 conv (384) feeding both a 1x3 and a 3x1 conv
    (384 each); a 1x1 (448) -> 3x3 (384) stack likewise split into 1x3 / 3x1
    (384 each); and a 3x3 average pool -> 1x1 conv (192).
    Output channels: 320 + 384*2 + 384*2 + 192 = 2048.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = Conv2d(in_channels=in_channels, out_channels=320, kernel_size=1)
        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=1)
        self.branch3x3_2a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3x3dbl_1 = Conv2d(in_channels=in_channels, out_channels=448, kernel_size=1)
        self.branch3x3dbl_2 = Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))
        self.branch_pool = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)

    def forward(self, x):
        x_1x1 = self.branch1x1(x)

        # Shared 1x1 reduction, then parallel 1x3 and 3x1 convs on the same input.
        x_3x3 = self.branch3x3_1(x)
        x_3x3 = torch.cat([self.branch3x3_2a(x_3x3), self.branch3x3_2b(x_3x3)], 1)

        x_3x3_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        x_3x3_dbl = torch.cat([self.branch3x3dbl_3a(x_3x3_dbl), self.branch3x3dbl_3b(x_3x3_dbl)], 1)

        # The keyword must be `kernel_size`; the misspelling `kerbel_size`
        # made F.avg_pool2d raise TypeError on every forward pass.
        x_branch = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        x_branch = self.branch_pool(x_branch)

        # 320 + 384*2 + 384*2 + 192 = 2048 channels along dim 1.
        return torch.cat([x_1x1, x_3x3, x_3x3_dbl, x_branch], 1)
Aux(辅助分类器)结构:
代码:
class Inception_Aux(nn.Module):
    """Auxiliary classifier head attached to a 17x17 feature map.

    Args:
        in_channels: channels of the incoming feature map (768 in Inception_v3).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv0 = Conv2d(in_channels, 128, kernel_size=1)
        self.conv1 = Conv2d(128, 768, kernel_size=5)
        self.fc = nn.Linear(768, num_classes)
        # Per-layer std hints. Nothing in this class reads them.
        # NOTE(review): presumably intended for an external weight-init routine.
        self.conv1.stddev = 0.01
        self.fc.stddev = 0.001

    def forward(self, x):
        # Shape notes assume an N x 768 x 17 x 17 input.
        out = F.avg_pool2d(x, kernel_size=5, stride=3)   # N x 768 x 5 x 5
        out = self.conv0(out)                             # N x 128 x 5 x 5
        out = self.conv1(out)                             # N x 768 x 1 x 1
        out = F.adaptive_max_pool2d(out, (1, 1))          # N x 768 x 1 x 1
        return self.fc(torch.flatten(out, 1))             # N x num_classes
主体结构:
Inception v3 整体模型结构
主体代码:
class Inception_v3(nn.Module):
    """Inception v3 classifier (https://arxiv.org/abs/1512.00567).

    Args:
        num_classes: size of the final classification layer (and of the
            auxiliary head, if built).
        aux_logits: if truthy, build the auxiliary classifier that reads the
            output of the last InceptionC block; it is only evaluated in
            training mode.

    ``forward`` returns ``(logits, aux)``. ``aux`` is ``None`` unless the
    module is in training mode and ``aux_logits`` is set.
    """

    def __init__(self, num_classes, aux_logits):
        super().__init__()
        # Stem: 3 conv -> max pool -> 2 conv -> max pool
        self.conv_1 = Conv2d(3, 32, kernel_size=3, stride=2)
        self.conv_2 = Conv2d(32, 32, kernel_size=3)
        self.conv_3 = Conv2d(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv_4 = Conv2d(64, 80, kernel_size=1)
        self.conv_5 = Conv2d(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        # 35x35 stage: out = 64 + 64 + 96 + pool_channels
        self.inception_a_1 = InceptionA(192, pool_channels=32)  # -> 256
        self.inception_a_2 = InceptionA(256, pool_channels=64)  # -> 288
        self.inception_a_3 = InceptionA(288, pool_channels=64)  # -> 288
        # Downsample to 17x17: 384 + 96 + 288
        self.inception_b_1 = InceptionB(288)  # -> 768
        # 17x17 stage: 192 * 4 = 768
        self.inception_c_1 = InceptionC(768, channels_7x7=128)
        self.inception_c_2 = InceptionC(768, channels_7x7=160)
        self.inception_c_3 = InceptionC(768, channels_7x7=160)
        self.inception_c_4 = InceptionC(768, channels_7x7=192)
        self.aux_logits = aux_logits
        if aux_logits:
            self.auxlogits = Inception_Aux(768, num_classes=num_classes)
        # Downsample to 8x8, then the 8x8 stage
        self.inception_d = InceptionD(768)      # -> 1280
        self.inception_e_1 = InceptionE(1280)   # -> 2048
        self.inception_e_2 = InceptionE(2048)   # -> 2048
        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # Shape notes assume an N x 3 x 299 x 299 input.
        for stage in (
            self.conv_1,         # N x 32 x 149 x 149
            self.conv_2,         # N x 32 x 147 x 147
            self.conv_3,         # N x 64 x 147 x 147
            self.maxpool1,       # N x 64 x 73 x 73
            self.conv_4,         # N x 80 x 73 x 73
            self.conv_5,         # N x 192 x 71 x 71
            self.maxpool2,       # N x 192 x 35 x 35
            self.inception_a_1,  # N x 256 x 35 x 35
            self.inception_a_2,  # N x 288 x 35 x 35
            self.inception_a_3,  # N x 288 x 35 x 35
            self.inception_b_1,  # N x 768 x 17 x 17
            self.inception_c_1,  # N x 768 x 17 x 17
            self.inception_c_2,  # N x 768 x 17 x 17
            self.inception_c_3,  # N x 768 x 17 x 17
            self.inception_c_4,  # N x 768 x 17 x 17
        ):
            x = stage(x)

        # Auxiliary logits only while training with the aux head enabled.
        aux = self.auxlogits(x) if self.training and self.aux_logits else None

        for stage in (
            self.inception_d,    # N x 1280 x 8 x 8
            self.inception_e_1,  # N x 2048 x 8 x 8
            self.inception_e_2,  # N x 2048 x 8 x 8
            self.avgpool,        # N x 2048 x 1 x 1
            self.dropout,        # N x 2048 x 1 x 1
        ):
            x = stage(x)

        logits = self.fc(torch.flatten(x, 1))  # N x num_classes
        return logits, aux
全部代码:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Module):
    """Basic unit used throughout the model: convolution -> BatchNorm -> ReLU.

    Wraps ``torch.nn.Conv2d`` (not to be confused with it).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm width).
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        bias: whether the convolution has its own bias; off by default
            because the following BatchNorm already has a learnable shift.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class InceptionA(nn.Module):
    """Inception block for the 35x35 stage; spatial size preserved.

    Output channels: 64 + 64 + 96 + pool_channels.

    Args:
        in_channels: channels of the input feature map.
        pool_channels: output channels of the 1x1 conv after the average pool.
    """

    def __init__(self, in_channels, pool_channels):
        super(InceptionA, self).__init__()
        self.branch1x1 = Conv2d(in_channels=in_channels, out_channels=64, kernel_size=1)
        self.branch5x5_1 = Conv2d(in_channels=in_channels, out_channels=48, kernel_size=1)
        self.branch5x5_2 = Conv2d(in_channels=48, out_channels=64, kernel_size=5, padding=2)
        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=64, kernel_size=1)
        self.branch3x3_2 = Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1)
        self.branch3x3_3 = Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1)
        self.branch_pool = Conv2d(in_channels=in_channels, out_channels=pool_channels, kernel_size=1)

    def forward(self, x):
        x_1x1 = self.branch1x1(x)
        x_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        x_3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        x_branch = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        # Concatenate along dim 1, the channel axis of an (N, C, H, W) tensor.
        return torch.cat([x_1x1, x_5x5, x_3x3, x_branch], 1)


class InceptionB(nn.Module):
    """Downsampling block (35x35 -> 17x17); output channels 384 + 96 + in_channels.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=3, stride=2)
        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=64, kernel_size=1)
        self.branch3x3_2 = Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1)
        # Must accept branch3x3_2's 96 output channels (63 here caused a
        # channel mismatch error on the first forward pass).
        self.branch3x3_3 = Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2)

    def forward(self, x):
        x_3x3 = self.branch3x3(x)
        x_3x3_1 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))
        x_branch = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([x_3x3, x_3x3_1, x_branch], 1)  # 384 + 96 + in_channels


class InceptionC(nn.Module):
    """17x17 Inception block with 7x7 convs factorized into 1x7 / 7x1 pairs.

    Output channels: 192 * 4 = 768; spatial size preserved.

    Args:
        in_channels: channels of the input feature map.
        channels_7x7: width of the intermediate layers in the 7x7 branches.
    """

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        c7 = channels_7x7
        self.branch_1x1 = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)
        self.branch_7x7_1 = Conv2d(in_channels=in_channels, out_channels=c7, kernel_size=1)
        self.branch_7x7_2 = Conv2d(in_channels=c7, out_channels=c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_3 = Conv2d(in_channels=c7, out_channels=192, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_4 = Conv2d(in_channels=in_channels, out_channels=c7, kernel_size=1)
        self.branch_7x7_5 = Conv2d(in_channels=c7, out_channels=c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_6 = Conv2d(in_channels=c7, out_channels=c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch_7x7_7 = Conv2d(in_channels=c7, out_channels=c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch_7x7_8 = Conv2d(in_channels=c7, out_channels=192, kernel_size=(1, 7), padding=(0, 3))
        self.branch_pool = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)

    def forward(self, x):
        x_1x1 = self.branch_1x1(x)
        x_7x7_1 = self.branch_7x7_3(self.branch_7x7_2(self.branch_7x7_1(x)))
        x_7x7_2 = self.branch_7x7_4(x)
        x_7x7_2 = self.branch_7x7_5(x_7x7_2)
        x_7x7_2 = self.branch_7x7_6(x_7x7_2)
        x_7x7_2 = self.branch_7x7_7(x_7x7_2)
        x_7x7_2 = self.branch_7x7_8(x_7x7_2)
        branch_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return torch.cat([x_1x1, x_7x7_1, x_7x7_2, branch_pool], 1)  # 192 * 4


class Inception_Aux(nn.Module):
    """Auxiliary classifier head attached to a 17x17 feature map.

    Args:
        in_channels: channels of the incoming feature map (768 in Inception_v3).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(Inception_Aux, self).__init__()
        self.conv0 = Conv2d(in_channels=in_channels, out_channels=128, kernel_size=1)
        self.conv1 = Conv2d(in_channels=128, out_channels=768, kernel_size=5)
        # NOTE(review): stddev attributes are not read anywhere in this file;
        # presumably intended for an external weight-init routine.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)  # N x 768 x 5 x 5
        x = self.conv0(x)                              # N x 128 x 5 x 5
        x = self.conv1(x)                              # N x 768 x 1 x 1
        x = F.adaptive_max_pool2d(x, (1, 1))           # N x 768 x 1 x 1
        x = torch.flatten(x, 1)                        # N x 768
        return self.fc(x)                              # N x num_classes


class InceptionD(nn.Module):
    """Downsampling block (17x17 -> 8x8); output channels 320 + 192 + in_channels.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)
        self.branch3x3_2 = Conv2d(in_channels=192, out_channels=320, kernel_size=3, stride=2)
        self.branch7x7x3_1 = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)
        self.branch7x7x3_2 = Conv2d(in_channels=192, out_channels=192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = Conv2d(in_channels=192, out_channels=192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=2)

    def forward(self, x):
        x_3x3_1 = self.branch3x3_2(self.branch3x3_1(x))
        x_7x7_1 = self.branch7x7x3_1(x)
        x_7x7_1 = self.branch7x7x3_2(x_7x7_1)
        x_7x7_1 = self.branch7x7x3_3(x_7x7_1)
        x_7x7_1 = self.branch7x7x3_4(x_7x7_1)
        branch = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([x_3x3_1, x_7x7_1, branch], 1)  # 320 + 192 + 768 = 1280


class InceptionE(nn.Module):
    """8x8 Inception block with split 1x3 / 3x1 outputs.

    Output channels: 320 + 384*2 + 384*2 + 192 = 2048; spatial size preserved.

    Args:
        in_channels: channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = Conv2d(in_channels=in_channels, out_channels=320, kernel_size=1)
        self.branch3x3_1 = Conv2d(in_channels=in_channels, out_channels=384, kernel_size=1)
        self.branch3x3_2a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3x3dbl_1 = Conv2d(in_channels=in_channels, out_channels=448, kernel_size=1)
        self.branch3x3dbl_2 = Conv2d(in_channels=448, out_channels=384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = Conv2d(in_channels=384, out_channels=384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = Conv2d(in_channels=384, out_channels=384, kernel_size=(3, 1), padding=(1, 0))
        self.branch_pool = Conv2d(in_channels=in_channels, out_channels=192, kernel_size=1)

    def forward(self, x):
        x_1x1 = self.branch1x1(x)

        x_3x3 = self.branch3x3_1(x)
        x_3x3 = torch.cat([self.branch3x3_2a(x_3x3), self.branch3x3_2b(x_3x3)], 1)

        x_3x3_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        x_3x3_dbl = torch.cat([self.branch3x3dbl_3a(x_3x3_dbl), self.branch3x3dbl_3b(x_3x3_dbl)], 1)

        # Keyword must be `kernel_size` (the misspelling `kerbel_size` raised TypeError).
        x_branch = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))

        # Concatenate on the channel axis (dim 1); the default dim 0 would fail
        # because the branches have different channel counts.
        return torch.cat([x_1x1, x_3x3, x_3x3_dbl, x_branch], 1)  # 2048


class Inception_v3(nn.Module):
    """Inception v3 classifier (https://arxiv.org/abs/1512.00567).

    Args:
        num_classes: size of the final (and auxiliary) classification layer.
        aux_logits: if truthy, build the auxiliary classifier after the last
            InceptionC block; it is only evaluated in training mode.

    ``forward`` returns ``(logits, aux)``; ``aux`` is ``None`` unless the
    module is training and ``aux_logits`` is set.
    """

    def __init__(self, num_classes, aux_logits):
        super(Inception_v3, self).__init__()
        # Stem: 3 conv -> 1 maxpool -> 2 conv -> 1 maxpool
        self.conv_1 = Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2)
        self.conv_2 = Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.conv_3 = Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv_4 = Conv2d(in_channels=64, out_channels=80, kernel_size=1)
        self.conv_5 = Conv2d(in_channels=80, out_channels=192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.inception_a_1 = InceptionA(in_channels=192, pool_channels=32)  # 64 + 64 + 96 + 32 = 256
        self.inception_a_2 = InceptionA(in_channels=256, pool_channels=64)  # 64 + 64 + 96 + 64 = 288
        self.inception_a_3 = InceptionA(in_channels=288, pool_channels=64)  # 288
        self.inception_b_1 = InceptionB(288)  # 384 + 96 + 288 = 768
        self.inception_c_1 = InceptionC(in_channels=768, channels_7x7=128)  # 192 * 4 = 768
        self.inception_c_2 = InceptionC(in_channels=768, channels_7x7=160)
        self.inception_c_3 = InceptionC(in_channels=768, channels_7x7=160)
        self.inception_c_4 = InceptionC(in_channels=768, channels_7x7=192)
        self.aux_logits = aux_logits
        if self.aux_logits:
            self.auxlogits = Inception_Aux(768, num_classes=num_classes)
        self.inception_d = InceptionD(in_channels=768)     # 1280
        self.inception_e_1 = InceptionE(in_channels=1280)  # 2048
        self.inception_e_2 = InceptionE(in_channels=2048)  # 2048
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # N x 3 x 299 x 299
        x = self.conv_1(x)         # N x 32 x 149 x 149
        x = self.conv_2(x)         # N x 32 x 147 x 147
        x = self.conv_3(x)         # N x 64 x 147 x 147
        x = self.maxpool1(x)       # N x 64 x 73 x 73
        x = self.conv_4(x)         # N x 80 x 73 x 73
        x = self.conv_5(x)         # N x 192 x 71 x 71
        x = self.maxpool2(x)       # N x 192 x 35 x 35
        x = self.inception_a_1(x)  # N x 256 x 35 x 35
        x = self.inception_a_2(x)  # N x 288 x 35 x 35
        x = self.inception_a_3(x)  # N x 288 x 35 x 35
        x = self.inception_b_1(x)  # N x 768 x 17 x 17
        x = self.inception_c_1(x)  # N x 768 x 17 x 17
        x = self.inception_c_2(x)  # N x 768 x 17 x 17
        x = self.inception_c_3(x)  # N x 768 x 17 x 17
        x = self.inception_c_4(x)  # N x 768 x 17 x 17
        # Auxiliary logits only while training with the aux head enabled.
        aux = self.auxlogits(x) if self.training and self.aux_logits else None
        x = self.inception_d(x)    # N x 1280 x 8 x 8
        x = self.inception_e_1(x)  # N x 2048 x 8 x 8
        x = self.inception_e_2(x)  # N x 2048 x 8 x 8
        x = self.avgpool(x)        # N x 2048 x 1 x 1
        x = self.dropout(x)        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)    # N x 2048
        x = self.fc(x)             # N x num_classes
        return x, aux