PyTorch Implementation of MobileNet_v3 for Image Classification on CIFAR10

Contents

1. Preface

2. Network Architecture

   (1) The hard-swish activation function

   (2) The bneck block

   (3) Overall architecture

3. Parameter Count

4. Code

5. Training Results

6. Full Code


1. Preface

        MobileNet_v3 builds on and improves both MobileNet_v2 and MnasNet.

        First, a brief introduction to MnasNet. MnasNet was proposed between MobileNet_v2 and MobileNet_v3 and relies mainly on NAS (neural architecture search) to find different blocks. Most earlier NAS approaches searched for a single cell and then stacked it repeatedly, which limits the resulting network (mainly in accuracy and latency). MnasNet increases architectural diversity by searching for different blocks, and, to better reflect the latency of a model deployed on a phone, it uses multi-objective optimization that balances on-device latency against model accuracy.

        MobileNet_v3 first obtains a model via NAS, then modifies the first layer (the first convolution originally had 32 filters, reduced to 16) and redesigns the last stage (the classification head). The paper also proposes the hard-swish nonlinearity (the authors show experimentally that it only helps in the deeper layers, which is why the first few layers in the architecture still use the ReLU activation).

        MobileNet_v3 keeps the inverted-residual structure (Inverted Residuals) from MobileNet_v2 and adds the attention (SE) mechanism from MnasNet; together these form the bneck block described in the paper. The paper proposes two models, MobileNetV3-Small and MobileNetV3-Large; the code in this post implements MobileNetV3-Large.

2. Network Architecture

   (1) The hard-swish activation function

        hard-swish is an optimized variant of the swish activation that greatly reduces computational cost.

        The swish activation is defined as swish(x) = x · sigmoid(x), while hard-swish is defined as hard-swish(x) = x · ReLU6(x + 3) / 6. Since ReLU6 is a simple piecewise-linear clamp, it is much cheaper to compute than the sigmoid, and the figure below shows that hard-swish behaves almost identically to swish.

[Figure 1: swish vs. hard-swish activation curves]
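        To make the comparison concrete, here is a minimal sketch (my own illustration, not code from the post's repository) that evaluates both activations with PyTorch; note that PyTorch also ships nn.Hardswish, which the implementation below uses:

import torch
import torch.nn.functional as F

def swish(x):
    # swish(x) = x * sigmoid(x)
    return x * torch.sigmoid(x)

def hard_swish(x):
    # hard-swish(x) = x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3) / 6

x = torch.linspace(-6, 6, steps=13)
print(torch.max(torch.abs(swish(x) - hard_swish(x))))  # small gap: the two curves nearly coincide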

   (2) The bneck block

[Figure 2: structure of the bneck block]

        First a 1×1 convolution expands the channel dimension; the expanded feature map then goes through a depthwise (dw) convolution, followed by an SE module (which effectively computes a per-channel scaling factor for the expanded feature map); finally a 1×1 convolution projects the channels back down. In effect this is the MobileNet_v2 inverted-residual block with an SE module inserted in the middle. The shortcut rule is the same as in MobileNet_v2: a residual connection is used only when the input channel count equals the output channel count and the stride is 1.
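        The following shape-only sketch traces that data flow (channel numbers are illustrative, taken from one row of the paper's table; the SE module is omitted here for brevity, and the complete bneckModule appears in the Code section below):

import torch
import torch.nn as nn

inp, exp, oup, k, s = 40, 120, 40, 5, 1   # in/expanded/out channels, kernel size, stride
expand  = nn.Sequential(nn.Conv2d(inp, exp, 1, bias=False), nn.BatchNorm2d(exp), nn.ReLU6())  # 1x1 expansion
dwconv  = nn.Sequential(nn.Conv2d(exp, exp, k, s, padding=k // 2, groups=exp, bias=False),
                        nn.BatchNorm2d(exp), nn.ReLU6())                                      # depthwise conv
project = nn.Sequential(nn.Conv2d(exp, oup, 1, bias=False), nn.BatchNorm2d(oup))              # 1x1 projection

x = torch.randn(1, inp, 28, 28)
out = project(dwconv(expand(x)))
if inp == oup and s == 1:       # shortcut rule: same channels and stride 1
    out = out + x
print(out.shape)                # torch.Size([1, 40, 28, 28])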

   (3) Overall architecture

[Figure 3: MobileNetV3-Large architecture table (from the paper)]

        This is the architecture table given in the paper. Note that the first convolution has 16 filters and uses the HS activation. In the table, exp_size is the channel count after the expansion (the first 1×1 convolution) inside each bneck, SE indicates whether the SE module is used, NL is the type of nonlinearity (HS = hard-swish, RE = ReLU), and s is the stride.

[Figure 4: original vs. efficient last stage (from the paper)]

        The figure compares the last stage before and after the modification (figure from the original paper). The amount of computation drops substantially, and according to the paper the latency is reduced by 7 milliseconds (11% of the running time). Also note that in the final classification part, NBN means no BatchNormalization is used.
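        A minimal sketch of the efficient last stage as it is realized in the code below: the 960-channel feature map is global-average-pooled to 1×1 first, so the expensive 1280-unit layer and the classifier run on a single vector per image, and no BatchNorm (NBN) is used in this head:

import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d((1, 1))
head = nn.Sequential(                       # no BatchNorm here (NBN)
    nn.Linear(960, 1280), nn.Hardswish(),
    nn.Dropout(p=0.2), nn.Linear(1280, 10))

feat = torch.randn(1, 960, 7, 7)            # output of the final 1x1 conv
out = head(pool(feat).flatten(1))           # pool first, then the expensive layers
print(out.shape)                            # torch.Size([1, 10])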

3. Parameter Count

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 16, 112, 112]             432
       BatchNorm2d-2         [-1, 16, 112, 112]              32
         Hardswish-3         [-1, 16, 112, 112]               0
          baseConv-4         [-1, 16, 112, 112]               0
            Conv2d-5         [-1, 16, 112, 112]             144
       BatchNorm2d-6         [-1, 16, 112, 112]              32
         Hardswish-7         [-1, 16, 112, 112]               0
          baseConv-8         [-1, 16, 112, 112]               0
            Conv2d-9         [-1, 16, 112, 112]             256
      BatchNorm2d-10         [-1, 16, 112, 112]              32
         Identity-11         [-1, 16, 112, 112]               0
         baseConv-12         [-1, 16, 112, 112]               0
      bneckModule-13         [-1, 16, 112, 112]               0
           Conv2d-14         [-1, 64, 112, 112]           1,024
      BatchNorm2d-15         [-1, 64, 112, 112]             128
            ReLU6-16         [-1, 64, 112, 112]               0
         baseConv-17         [-1, 64, 112, 112]               0
           Conv2d-18           [-1, 64, 56, 56]             576
      BatchNorm2d-19           [-1, 64, 56, 56]             128
            ReLU6-20           [-1, 64, 56, 56]               0
         baseConv-21           [-1, 64, 56, 56]               0
           Conv2d-22           [-1, 24, 56, 56]           1,536
      BatchNorm2d-23           [-1, 24, 56, 56]              48
         Identity-24           [-1, 24, 56, 56]               0
         baseConv-25           [-1, 24, 56, 56]               0
      bneckModule-26           [-1, 24, 56, 56]               0
           Conv2d-27           [-1, 72, 56, 56]           1,728
      BatchNorm2d-28           [-1, 72, 56, 56]             144
            ReLU6-29           [-1, 72, 56, 56]               0
         baseConv-30           [-1, 72, 56, 56]               0
           Conv2d-31           [-1, 72, 56, 56]             648
      BatchNorm2d-32           [-1, 72, 56, 56]             144
            ReLU6-33           [-1, 72, 56, 56]               0
         baseConv-34           [-1, 72, 56, 56]               0
           Conv2d-35           [-1, 24, 56, 56]           1,728
      BatchNorm2d-36           [-1, 24, 56, 56]              48
         Identity-37           [-1, 24, 56, 56]               0
         baseConv-38           [-1, 24, 56, 56]               0
      bneckModule-39           [-1, 24, 56, 56]               0
           Conv2d-40           [-1, 72, 56, 56]           1,728
      BatchNorm2d-41           [-1, 72, 56, 56]             144
            ReLU6-42           [-1, 72, 56, 56]               0
         baseConv-43           [-1, 72, 56, 56]               0
           Conv2d-44           [-1, 72, 28, 28]           1,800
      BatchNorm2d-45           [-1, 72, 28, 28]             144
            ReLU6-46           [-1, 72, 28, 28]               0
         baseConv-47           [-1, 72, 28, 28]               0
AdaptiveAvgPool2d-48             [-1, 72, 1, 1]               0
           Conv2d-49             [-1, 18, 1, 1]           1,314
            ReLU6-50             [-1, 18, 1, 1]               0
           Conv2d-51             [-1, 72, 1, 1]           1,368
        Hardswish-52             [-1, 72, 1, 1]               0
         SEModule-53           [-1, 72, 28, 28]               0
           Conv2d-54           [-1, 40, 28, 28]           2,880
      BatchNorm2d-55           [-1, 40, 28, 28]              80
         Identity-56           [-1, 40, 28, 28]               0
         baseConv-57           [-1, 40, 28, 28]               0
      bneckModule-58           [-1, 40, 28, 28]               0
           Conv2d-59          [-1, 120, 28, 28]           4,800
      BatchNorm2d-60          [-1, 120, 28, 28]             240
            ReLU6-61          [-1, 120, 28, 28]               0
         baseConv-62          [-1, 120, 28, 28]               0
           Conv2d-63          [-1, 120, 28, 28]           3,000
      BatchNorm2d-64          [-1, 120, 28, 28]             240
            ReLU6-65          [-1, 120, 28, 28]               0
         baseConv-66          [-1, 120, 28, 28]               0
AdaptiveAvgPool2d-67            [-1, 120, 1, 1]               0
           Conv2d-68             [-1, 30, 1, 1]           3,630
            ReLU6-69             [-1, 30, 1, 1]               0
           Conv2d-70            [-1, 120, 1, 1]           3,720
        Hardswish-71            [-1, 120, 1, 1]               0
         SEModule-72          [-1, 120, 28, 28]               0
           Conv2d-73           [-1, 40, 28, 28]           4,800
      BatchNorm2d-74           [-1, 40, 28, 28]              80
         Identity-75           [-1, 40, 28, 28]               0
         baseConv-76           [-1, 40, 28, 28]               0
      bneckModule-77           [-1, 40, 28, 28]               0
           Conv2d-78          [-1, 120, 28, 28]           4,800
      BatchNorm2d-79          [-1, 120, 28, 28]             240
            ReLU6-80          [-1, 120, 28, 28]               0
         baseConv-81          [-1, 120, 28, 28]               0
           Conv2d-82          [-1, 120, 28, 28]           3,000
      BatchNorm2d-83          [-1, 120, 28, 28]             240
            ReLU6-84          [-1, 120, 28, 28]               0
         baseConv-85          [-1, 120, 28, 28]               0
AdaptiveAvgPool2d-86            [-1, 120, 1, 1]               0
           Conv2d-87             [-1, 30, 1, 1]           3,630
            ReLU6-88             [-1, 30, 1, 1]               0
           Conv2d-89            [-1, 120, 1, 1]           3,720
        Hardswish-90            [-1, 120, 1, 1]               0
         SEModule-91          [-1, 120, 28, 28]               0
           Conv2d-92           [-1, 40, 28, 28]           4,800
      BatchNorm2d-93           [-1, 40, 28, 28]              80
         Identity-94           [-1, 40, 28, 28]               0
         baseConv-95           [-1, 40, 28, 28]               0
      bneckModule-96           [-1, 40, 28, 28]               0
           Conv2d-97          [-1, 240, 28, 28]           9,600
      BatchNorm2d-98          [-1, 240, 28, 28]             480
        Hardswish-99          [-1, 240, 28, 28]               0
        baseConv-100          [-1, 240, 28, 28]               0
          Conv2d-101          [-1, 240, 14, 14]           2,160
     BatchNorm2d-102          [-1, 240, 14, 14]             480
       Hardswish-103          [-1, 240, 14, 14]               0
        baseConv-104          [-1, 240, 14, 14]               0
          Conv2d-105           [-1, 80, 14, 14]          19,200
     BatchNorm2d-106           [-1, 80, 14, 14]             160
        Identity-107           [-1, 80, 14, 14]               0
        baseConv-108           [-1, 80, 14, 14]               0
     bneckModule-109           [-1, 80, 14, 14]               0
          Conv2d-110          [-1, 200, 14, 14]          16,000
     BatchNorm2d-111          [-1, 200, 14, 14]             400
       Hardswish-112          [-1, 200, 14, 14]               0
        baseConv-113          [-1, 200, 14, 14]               0
          Conv2d-114          [-1, 200, 14, 14]           1,800
     BatchNorm2d-115          [-1, 200, 14, 14]             400
       Hardswish-116          [-1, 200, 14, 14]               0
        baseConv-117          [-1, 200, 14, 14]               0
          Conv2d-118           [-1, 80, 14, 14]          16,000
     BatchNorm2d-119           [-1, 80, 14, 14]             160
        Identity-120           [-1, 80, 14, 14]               0
        baseConv-121           [-1, 80, 14, 14]               0
     bneckModule-122           [-1, 80, 14, 14]               0
          Conv2d-123          [-1, 184, 14, 14]          14,720
     BatchNorm2d-124          [-1, 184, 14, 14]             368
       Hardswish-125          [-1, 184, 14, 14]               0
        baseConv-126          [-1, 184, 14, 14]               0
          Conv2d-127          [-1, 184, 14, 14]           1,656
     BatchNorm2d-128          [-1, 184, 14, 14]             368
       Hardswish-129          [-1, 184, 14, 14]               0
        baseConv-130          [-1, 184, 14, 14]               0
          Conv2d-131           [-1, 80, 14, 14]          14,720
     BatchNorm2d-132           [-1, 80, 14, 14]             160
        Identity-133           [-1, 80, 14, 14]               0
        baseConv-134           [-1, 80, 14, 14]               0
     bneckModule-135           [-1, 80, 14, 14]               0
          Conv2d-136          [-1, 184, 14, 14]          14,720
     BatchNorm2d-137          [-1, 184, 14, 14]             368
       Hardswish-138          [-1, 184, 14, 14]               0
        baseConv-139          [-1, 184, 14, 14]               0
          Conv2d-140          [-1, 184, 14, 14]           1,656
     BatchNorm2d-141          [-1, 184, 14, 14]             368
       Hardswish-142          [-1, 184, 14, 14]               0
        baseConv-143          [-1, 184, 14, 14]               0
          Conv2d-144           [-1, 80, 14, 14]          14,720
     BatchNorm2d-145           [-1, 80, 14, 14]             160
        Identity-146           [-1, 80, 14, 14]               0
        baseConv-147           [-1, 80, 14, 14]               0
     bneckModule-148           [-1, 80, 14, 14]               0
          Conv2d-149          [-1, 480, 14, 14]          38,400
     BatchNorm2d-150          [-1, 480, 14, 14]             960
       Hardswish-151          [-1, 480, 14, 14]               0
        baseConv-152          [-1, 480, 14, 14]               0
          Conv2d-153          [-1, 480, 14, 14]           4,320
     BatchNorm2d-154          [-1, 480, 14, 14]             960
       Hardswish-155          [-1, 480, 14, 14]               0
        baseConv-156          [-1, 480, 14, 14]               0
AdaptiveAvgPool2d-157            [-1, 480, 1, 1]               0
          Conv2d-158            [-1, 120, 1, 1]          57,720
           ReLU6-159            [-1, 120, 1, 1]               0
          Conv2d-160            [-1, 480, 1, 1]          58,080
       Hardswish-161            [-1, 480, 1, 1]               0
        SEModule-162          [-1, 480, 14, 14]               0
          Conv2d-163          [-1, 112, 14, 14]          53,760
     BatchNorm2d-164          [-1, 112, 14, 14]             224
        Identity-165          [-1, 112, 14, 14]               0
        baseConv-166          [-1, 112, 14, 14]               0
     bneckModule-167          [-1, 112, 14, 14]               0
          Conv2d-168          [-1, 672, 14, 14]          75,264
     BatchNorm2d-169          [-1, 672, 14, 14]           1,344
       Hardswish-170          [-1, 672, 14, 14]               0
        baseConv-171          [-1, 672, 14, 14]               0
          Conv2d-172          [-1, 672, 14, 14]           6,048
     BatchNorm2d-173          [-1, 672, 14, 14]           1,344
       Hardswish-174          [-1, 672, 14, 14]               0
        baseConv-175          [-1, 672, 14, 14]               0
AdaptiveAvgPool2d-176            [-1, 672, 1, 1]               0
          Conv2d-177            [-1, 168, 1, 1]         113,064
           ReLU6-178            [-1, 168, 1, 1]               0
          Conv2d-179            [-1, 672, 1, 1]         113,568
       Hardswish-180            [-1, 672, 1, 1]               0
        SEModule-181          [-1, 672, 14, 14]               0
          Conv2d-182          [-1, 112, 14, 14]          75,264
     BatchNorm2d-183          [-1, 112, 14, 14]             224
        Identity-184          [-1, 112, 14, 14]               0
        baseConv-185          [-1, 112, 14, 14]               0
     bneckModule-186          [-1, 112, 14, 14]               0
          Conv2d-187          [-1, 672, 14, 14]          75,264
     BatchNorm2d-188          [-1, 672, 14, 14]           1,344
       Hardswish-189          [-1, 672, 14, 14]               0
        baseConv-190          [-1, 672, 14, 14]               0
          Conv2d-191            [-1, 672, 7, 7]          16,800
     BatchNorm2d-192            [-1, 672, 7, 7]           1,344
       Hardswish-193            [-1, 672, 7, 7]               0
        baseConv-194            [-1, 672, 7, 7]               0
AdaptiveAvgPool2d-195            [-1, 672, 1, 1]               0
          Conv2d-196            [-1, 168, 1, 1]         113,064
           ReLU6-197            [-1, 168, 1, 1]               0
          Conv2d-198            [-1, 672, 1, 1]         113,568
       Hardswish-199            [-1, 672, 1, 1]               0
        SEModule-200            [-1, 672, 7, 7]               0
          Conv2d-201            [-1, 160, 7, 7]         107,520
     BatchNorm2d-202            [-1, 160, 7, 7]             320
        Identity-203            [-1, 160, 7, 7]               0
        baseConv-204            [-1, 160, 7, 7]               0
     bneckModule-205            [-1, 160, 7, 7]               0
          Conv2d-206            [-1, 960, 7, 7]         153,600
     BatchNorm2d-207            [-1, 960, 7, 7]           1,920
       Hardswish-208            [-1, 960, 7, 7]               0
        baseConv-209            [-1, 960, 7, 7]               0
          Conv2d-210            [-1, 960, 7, 7]          24,000
     BatchNorm2d-211            [-1, 960, 7, 7]           1,920
       Hardswish-212            [-1, 960, 7, 7]               0
        baseConv-213            [-1, 960, 7, 7]               0
AdaptiveAvgPool2d-214            [-1, 960, 1, 1]               0
          Conv2d-215            [-1, 240, 1, 1]         230,640
           ReLU6-216            [-1, 240, 1, 1]               0
          Conv2d-217            [-1, 960, 1, 1]         231,360
       Hardswish-218            [-1, 960, 1, 1]               0
        SEModule-219            [-1, 960, 7, 7]               0
          Conv2d-220            [-1, 160, 7, 7]         153,600
     BatchNorm2d-221            [-1, 160, 7, 7]             320
        Identity-222            [-1, 160, 7, 7]               0
        baseConv-223            [-1, 160, 7, 7]               0
     bneckModule-224            [-1, 160, 7, 7]               0
          Conv2d-225            [-1, 960, 7, 7]         153,600
     BatchNorm2d-226            [-1, 960, 7, 7]           1,920
       Hardswish-227            [-1, 960, 7, 7]               0
        baseConv-228            [-1, 960, 7, 7]               0
          Conv2d-229            [-1, 960, 7, 7]          24,000
     BatchNorm2d-230            [-1, 960, 7, 7]           1,920
       Hardswish-231            [-1, 960, 7, 7]               0
        baseConv-232            [-1, 960, 7, 7]               0
AdaptiveAvgPool2d-233            [-1, 960, 1, 1]               0
          Conv2d-234            [-1, 240, 1, 1]         230,640
           ReLU6-235            [-1, 240, 1, 1]               0
          Conv2d-236            [-1, 960, 1, 1]         231,360
       Hardswish-237            [-1, 960, 1, 1]               0
        SEModule-238            [-1, 960, 7, 7]               0
          Conv2d-239            [-1, 160, 7, 7]         153,600
     BatchNorm2d-240            [-1, 160, 7, 7]             320
        Identity-241            [-1, 160, 7, 7]               0
        baseConv-242            [-1, 160, 7, 7]               0
     bneckModule-243            [-1, 160, 7, 7]               0
          Conv2d-244            [-1, 960, 7, 7]         153,600
     BatchNorm2d-245            [-1, 960, 7, 7]           1,920
       Hardswish-246            [-1, 960, 7, 7]               0
        baseConv-247            [-1, 960, 7, 7]               0
AdaptiveAvgPool2d-248            [-1, 960, 1, 1]               0
          Linear-249                 [-1, 1280]       1,230,080
       Hardswish-250                 [-1, 1280]               0
         Dropout-251                 [-1, 1280]               0
          Linear-252                   [-1, 10]          12,810
================================================================
Total params: 4,213,008
Trainable params: 4,213,008
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 143.36
Params size (MB): 16.07
Estimated Total Size (MB): 160.01
----------------------------------------------------------------

4. Code

import torch.nn as nn
from collections import OrderedDict
import torch
from torchsummary import summary

# Basic Conv-BN-activation block
class baseConv(nn.Module):
    def __init__(self,inchannel,outchannel,kernel_size,stride,groups=1,active=None,bias=False):
        super(baseConv, self).__init__()

        # select the activation: 'HS' -> Hardswish, 'RE' -> ReLU6, otherwise identity
        if active=='HS':
            ac=nn.Hardswish
        elif active=='RE':
            ac=nn.ReLU6
        else:
            ac=nn.Identity

        pad=kernel_size//2
        self.base=nn.Sequential(
            nn.Conv2d(in_channels=inchannel,out_channels=outchannel,kernel_size=kernel_size,stride=stride,padding=pad,groups=groups,bias=bias),
            nn.BatchNorm2d(outchannel),
            ac()
        )
    def forward(self,x):
        x=self.base(x)
        return x

# SE (squeeze-and-excitation) module
class SEModule(nn.Module):
    def __init__(self,inchannels):
        super(SEModule, self).__init__()
        hidden_channel=int(inchannels/4)
        self.pool=nn.AdaptiveAvgPool2d((1,1))
        self.linear1=nn.Sequential(
            nn.Conv2d(inchannels,hidden_channel,1),
            nn.ReLU6()
        )
        self.linear2=nn.Sequential(
            nn.Conv2d(hidden_channel,inchannels,1),
            nn.Hardswish()
        )

    def forward(self,x):
        out=self.pool(x)
        out=self.linear1(out)
        out=self.linear2(out)
        return out*x

# bneck block: expand -> depthwise -> (SE) -> project
class bneckModule(nn.Module):
    def __init__(self,inchannels,expand_channels,outchannels,kernel_size,stride,SE,activate):
        super(bneckModule, self).__init__()
        self.module=[]     # layers that make up this block

        if inchannels!=expand_channels:         # the 1x1 expansion layer is only needed when in/expand channels differ
            self.module.append(baseConv(inchannels,expand_channels,kernel_size=1,stride=1,active=activate))

        self.module.append(baseConv(expand_channels,expand_channels,kernel_size=kernel_size,stride=stride,active=activate,groups=expand_channels))

        # optionally insert the SE module after the depthwise conv
        if SE:
            self.module.append(SEModule(expand_channels))

        self.module.append(baseConv(expand_channels,outchannels,1,1))
        self.module=nn.Sequential(*self.module)

        # a shortcut is used only when input/output channels match and stride is 1
        self.residual=False
        if inchannels==outchannels and stride==1:
            self.residual=True

    def forward(self,x):
        out1=self.module(x)
        if self.residual:
            return out1+x
        else:
            return out1


# MobileNet_v3 (Large) network
class mobilenet_v3(nn.Module):
    def __init__(self,num_classes,init_weight=True):
        super(mobilenet_v3, self).__init__()

        # [inchannel,expand_channels,outchannels,kernel_size,stride,SE,activate]
        net_config = [[16, 16, 16, 3, 1, False, 'HS'],
                      [16, 64, 24, 3, 2, False, 'RE'],
                      [24, 72, 24, 3, 1, False, 'RE'],
                      [24, 72, 40, 5, 2, True, 'RE'],
                      [40, 120, 40, 5, 1, True, 'RE'],
                      [40, 120, 40, 5, 1, True, 'RE'],
                      [40, 240, 80, 3, 2, False, 'HS'],
                      [80, 200, 80, 3, 1, False, 'HS'],
                      [80, 184, 80, 3, 1, False, 'HS'],
                      [80, 184, 80, 3, 1, False, 'HS'],
                      [80, 480, 112, 3, 1, True, 'HS'],
                      [112, 672, 112, 3, 1, True, 'HS'],
                      [112, 672, 160, 5, 2, True, 'HS'],
                      [160, 960, 160, 5, 1, True, 'HS'],
                      [160, 960, 160, 5, 1, True, 'HS']]

        # ordered dict holding the network layers
        modules=OrderedDict()
        modules.update({'layer1':baseConv(inchannel=3,kernel_size=3,outchannel=16,stride=2,active='HS')})

        # build the bneck stack from the config table
        for idx,layer in enumerate(net_config):
            modules.update({'bneck_{}'.format(idx):bneckModule(layer[0],layer[1],layer[2],layer[3],layer[4],layer[5],layer[6])})

        modules.update({'conv_1*1':baseConv(net_config[-1][2],960,1,stride=1,active='HS')})
        modules.update({'pool':nn.AdaptiveAvgPool2d((1,1))})

        self.module=nn.Sequential(modules)

        self.classifier=nn.Sequential(
            nn.Linear(960,1280),
            nn.Hardswish(),
            nn.Dropout(p=0.2),
            nn.Linear(1280,num_classes)
        )

        if init_weight:
            self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)
            elif isinstance(w, nn.Linear):
                nn.init.normal_(w.weight, 0, 0.01)
                nn.init.zeros_(w.bias)


    def forward(self,x):
        out=self.module(x)
        out=out.view(out.size(0),-1)
        out=self.classifier(out)
        return out


if __name__ == '__main__':
    device='cuda' if torch.cuda.is_available() else 'cpu'
    net=mobilenet_v3(10).to(device)
    summary(net,(3,224,224),device=device)

        The training and validation code is similar to that for MobileNet_v2.
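        For completeness, here is a minimal training/validation sketch under stated assumptions: resizing CIFAR10 images to 224×224 matches the summary above, while the optimizer (Adam, lr=1e-3), batch size 64, epoch count, and normalization values are my illustrative choices, not necessarily what the original post used. It logs to runs/ for the TensorBoard visualization in the next section:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Resize CIFAR10 (32x32) to 224x224 so the strides/shapes match the summary above
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

net = mobilenet_v3(10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)   # illustrative choice
writer = SummaryWriter('runs')                            # view with: tensorboard --logdir=runs

for epoch in range(9):
    net.train()
    for step, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(net(imgs), labels)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            writer.add_scalar('train_loss', loss.item(), epoch * len(train_loader) + step)

    # validation accuracy on the test split
    net.eval()
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = net(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    writer.add_scalar('val_accuracy', correct / total, epoch)
    print(f'epoch {epoch}: val acc {correct / total:.4f}')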

5. Training Results

        Here I trained for 9 epochs and the accuracy reached about 79%, which is quite a good result.

        Visualizing the training loss and validation accuracy (run tensorboard --logdir=runs in a terminal):

[Figure 5: TensorBoard curves of training loss and validation accuracy]

6. Full Code

        Code: https://pan.baidu.com/s/1dgLLRdco_kYPxEkNfFs3yw (extraction code: brsm)

        Pretrained weights: https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth
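        Note that these are torchvision's official MobileNetV3-Large weights: their state_dict keys follow torchvision's layer names and will not load directly into the custom mobilenet_v3 class above. A sketch of using them with torchvision's own model instead (the classifier[3] index follows torchvision's layout):

import torch
import torchvision

# Load the official torchvision MobileNetV3-Large weights from the URL above.
# The keys match torchvision's model, not the custom mobilenet_v3 defined in this post.
model = torchvision.models.mobilenet_v3_large()
state_dict = torch.hub.load_state_dict_from_url(
    'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')
model.load_state_dict(state_dict)

# Replace the 1000-class head with a 10-class one for CIFAR10 fine-tuning
model.classifier[3] = torch.nn.Linear(model.classifier[3].in_features, 10)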
