Pytorch MobileNet代码

Pytorch MobileNet代码

  • 论文中定义的网络结构
  • 代码部分

论文中定义的网络结构

Pytorch MobileNet代码_第1张图片
Fig.3左边的结构对应代码中的conv_bn,表示带有BN和ReLU的标准卷积层。
Fig.3右边的结构对应代码中的conv_dw,表示带有逐深度卷积和逐点卷积层的深度可分离卷积层。

Pytorch MobileNet代码_第2张图片
上图是MobileNet的网络主体结构。

代码部分

class MobileNet(nn.Module):
    def __init__(self):
        super(MobileNet, self).__init__()

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True)
            )

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),  # groups=inp:逐深度卷积
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),
    
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),   				# padding=0
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential( 					# inp, oup, stride
            conv_bn(  3,  32, 2),   					# 论文中的Table1的Conv dw + Conv是一个深度可分离卷积
            conv_dw( 32,  64, 1),   					# stride是逐深度卷积的步长(第一个和最后一个Conv除外)
            conv_dw( 64, 128, 2),   					# inp, oup是Conv中的尺寸
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(1024, 1000)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, 1024)
        x = self.fc(x)
        return x

你可能感兴趣的:(Pytorch)