EfficientNet: Classifying the Flower Dataset

Contents

1. The EfficientNet network

2. Depth, width, and resolution

3. EfficientNet architecture

4. Training the network from the command line

5. Code

5.1 model

5.2 dataset

5.3 utils

5.4 train

5.5 predict


1. The EfficientNet network

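EfficientNet investigates three key dimensions along which a network can be scaled: input image resolution, network width, and network depth.

Image resolution: the spatial size h×w of the input image, and therefore of the feature maps computed from it.

Network width: the number of feature maps in a layer, i.e. the number of convolution kernels, or output channels.

Network depth: the number of layers, e.g. ResNet-34 vs. ResNet-101.

As shown below: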

[Figure 1]

 

2. Depth, width, and resolution

At some point, a 224×224 input resolution became the de facto standard for neural networks, and almost every later network was designed for 224×224 inputs.

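Even when a network is specified for 224×224, feeding it another size such as 300×300 usually works. This is because most implementations place an adaptive pooling layer before the fully connected layer, and that layer produces a fixed-size output regardless of the input's spatial size.

Without adaptive pooling, a different input size would change the number of flattened features reaching the fully connected layer, and the forward pass would fail with a shape error.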
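
A minimal sketch of this point, assuming PyTorch. The same toy convolutional backbone (hypothetical, not EfficientNet) feeds two classifier heads. One head uses `nn.AdaptiveAvgPool2d(1)`. The other flattens directly into a `Linear` layer sized for 224×224 inputs. Only the adaptive head accepts a 300×300 input.

```python
import torch
import torch.nn as nn

# Toy backbone: two stride-2 convs, so the output H x W depends on the input size.
features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
)

# Head A: adaptive pooling squeezes any H x W to 1 x 1, so Linear always sees 32 features.
adaptive_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 5))

# Head B: flatten directly; Linear is hard-wired to a 224 x 224 input
# (224 -> 112 -> 56 after two stride-2 convs, i.e. 32 * 56 * 56 features).
flatten_head = nn.Sequential(nn.Flatten(), nn.Linear(32 * 56 * 56, 5))


def try_forward(head, size):
    """Run one random image of the given size; return the output shape or the error type."""
    x = torch.randn(1, 3, size, size)
    try:
        return tuple(head(features(x)).shape)
    except RuntimeError as err:
        return f"error: {type(err).__name__}"


for size in (224, 300):
    print(size, "adaptive:", try_forward(adaptive_head, size),
          "flatten:", try_forward(flatten_head, size))
```

At 300×300 the backbone outputs 75×75 maps, so the flattened vector no longer matches the 32·56·56 inputs that head B expects.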

So, with the resolution held fixed, later networks concentrated on scaling width or depth. ResNet, for example, has been scaled to more than 1000 layers.

Below is a brief look at what each of the three dimensions contributes.

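Width: increase the number of channels. Wider networks tend to capture more fine-grained features and are easier to train. However, very wide but shallow networks have difficulty capturing higher-level features. Empirically, accuracy saturates quickly as the network gets wider (larger w).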

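Depth: increase the number of layers. Scaling depth is the most common way to scale convolutional networks. A deeper ConvNet can capture richer and more complex features and generalizes well to new tasks. However, deeper networks are also harder to train because of vanishing gradients. Techniques such as shortcut connections and batch normalization ease the training problem, but the accuracy gain from extra depth still diminishes.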

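Resolution: with higher-resolution input images, the convolutions can potentially capture more fine-grained patterns. Higher resolution does improve accuracy, but the gain diminishes at very high resolutions.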
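
The three dimensions also differ in cost. A convolution's multiply-accumulates grow linearly with depth but quadratically with width (channels in × channels out) and with resolution (h × w). The rough count below uses a toy stack of identical stride-1 convolutions, an illustrative model rather than EfficientNet itself.

```python
def conv_flops(c_in, c_out, kernel, h, w):
    """Multiply-accumulate count of one stride-1 'same' conv layer (bias ignored)."""
    return c_in * c_out * kernel * kernel * h * w


def stack_flops(depth, width, resolution, kernel=3):
    """Toy network: `depth` identical conv layers with `width` channels
    on a `resolution` x `resolution` feature map (stride 1, size never changes)."""
    return depth * conv_flops(width, width, kernel, resolution, resolution)


base = stack_flops(depth=10, width=32, resolution=56)
ratios = {
    "2x depth": stack_flops(20, 32, 56) / base,
    "2x width": stack_flops(10, 64, 56) / base,
    "2x resolution": stack_flops(10, 32, 112) / base,
}
print(ratios)  # {'2x depth': 2.0, '2x width': 4.0, '2x resolution': 4.0}
```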

The authors' conclusions:

[Figure 2]

[Figure 3]

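EfficientNet argues that balancing how the three dimensions are scaled is essential, because they are not independent of one another. Intuitively, higher-resolution images call for more depth, so that the larger receptive field can capture similar features that now span more pixels. Higher resolution also calls for more width, so that the network can capture fine-grained patterns across those extra pixels. These intuitions suggest coordinating and balancing the scaling dimensions, rather than scaling along a single dimension as was conventional.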

[Figure 4]
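
The EfficientNet paper formalizes this as compound scaling. A single coefficient φ scales depth by α^φ, width by β^φ and resolution by γ^φ. The constraint α·β²·γ² ≈ 2 (with α, β, γ ≥ 1) means total FLOPs grow by roughly 2^φ. The sketch below uses the values α = 1.2, β = 1.1, γ = 1.15 reported in the paper. The published B1 to B7 coefficients (see the code in section 5.1) were rounded and do not follow this formula exactly.

```python
# Constants reported in the EfficientNet paper (grid search on B0 with phi = 1).
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # depth, width, resolution bases


def compound_scale(phi, base_resolution=224):
    """Return (depth multiplier, width multiplier, input resolution, approx. FLOPs multiplier)."""
    depth_mult = ALPHA ** phi
    width_mult = BETA ** phi
    resolution = round(base_resolution * GAMMA ** phi)
    # FLOPs grow roughly with d * w^2 * r^2, i.e. (alpha * beta^2 * gamma^2) ** phi
    flops_mult = (ALPHA * BETA ** 2 * GAMMA ** 2) ** phi
    return depth_mult, width_mult, resolution, flops_mult


for phi in range(4):
    d, w, r, f = compound_scale(phi)
    print(f"phi={phi}: depth x{d:.2f}, width x{w:.2f}, resolution {r}, FLOPs x{f:.2f}")
```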

 

3. EfficientNet architecture

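The basic building block of EfficientNet is called MBConv. It proceeds as follows:

- A 1×1 convolution expands the number of channels.
- A depthwise (dw) convolution is applied.
- An SE attention module reweights the channels.
- A second 1×1 convolution projects the channels back down.
- Dropout is applied.

When the block uses a shortcut, its input is added to this output. In the code in section 5.1, the expansion convolution is skipped when the expansion ratio is 1. The dropout there is stochastic depth (DropPath), applied only in blocks that have a shortcut.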

[Figure 5]

 

The SE attention module is shown below:

[Figure 6]
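
Here is a compact PyTorch sketch of an MBConv block with SE, written for illustration; the names `SqueezeExcite` and `MBConvSketch` are hypothetical. It follows the order described above: 1×1 expansion, depthwise convolution, SE, 1×1 projection without activation, then the shortcut. For brevity it always expands and omits stochastic depth. The fuller version is `InvertedResidual` in section 5.1.

```python
import torch
import torch.nn as nn


class SqueezeExcite(nn.Module):
    """SE: global avg pool -> 1x1 conv (reduce) -> SiLU -> 1x1 conv (expand) -> sigmoid -> rescale."""

    def __init__(self, channels, squeeze_channels):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeeze_channels, 1)  # 1x1 conv acts as a fully connected layer
        self.fc2 = nn.Conv2d(squeeze_channels, channels, 1)

    def forward(self, x):
        scale = torch.sigmoid(self.fc2(nn.functional.silu(self.fc1(self.pool(x)))))
        return x * scale  # one weight in (0, 1) per channel, broadcast over H x W


class MBConvSketch(nn.Module):
    def __init__(self, in_c, out_c, expand_ratio=6, kernel=3, stride=1):
        super().__init__()
        mid_c = in_c * expand_ratio
        self.use_shortcut = stride == 1 and in_c == out_c
        self.block = nn.Sequential(
            # 1x1 expansion conv + BN + Swish
            nn.Conv2d(in_c, mid_c, 1, bias=False), nn.BatchNorm2d(mid_c), nn.SiLU(),
            # depthwise conv: groups == channels, so each channel is filtered separately
            nn.Conv2d(mid_c, mid_c, kernel, stride, padding=kernel // 2, groups=mid_c, bias=False),
            nn.BatchNorm2d(mid_c), nn.SiLU(),
            # SE; squeeze width is based on the block input, like SqueezeExcitation in section 5.1
            SqueezeExcite(mid_c, max(1, in_c // 4)),
            # 1x1 projection conv + BN, no activation
            nn.Conv2d(mid_c, out_c, 1, bias=False), nn.BatchNorm2d(out_c),
        )

    def forward(self, x):
        out = self.block(x)
        return out + x if self.use_shortcut else out


x = torch.randn(2, 16, 32, 32)
print(MBConvSketch(16, 16)(x).shape)            # torch.Size([2, 16, 32, 32]), shortcut used
print(MBConvSketch(16, 24, stride=2)(x).shape)  # torch.Size([2, 24, 16, 16]), no shortcut
```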

 

The structure of EfficientNet-B0 is as follows:

[Figure 7]

 

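EfficientNet-B1 through B7 are derived from B0 by increasing two hyperparameters: the width and depth coefficients (width_coefficient and depth_coefficient in the code). When you change these two values, you must also set the input image size manually according to the table: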
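
The effect of `width_coefficient` and `depth_coefficient` can be checked by hand. The sketch below applies B4's coefficients (1.4, 1.8). It uses the same rounding rule as `_make_divisible` and the same `math.ceil` rule as `round_repeats` from section 5.1, re-implemented here so the snippet runs on its own.

```python
import math


def make_divisible(ch, divisor=8, min_ch=None):
    # Same rule as _make_divisible in section 5.1:
    # nearest multiple of 8, but never more than 10% below ch.
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch


# Per-stage (out_channels, repeats) of EfficientNet-B0, from default_cnf in section 5.1.
B0_STAGES = [(16, 1), (24, 2), (40, 2), (80, 3), (112, 3), (192, 4), (320, 1)]


def scaled_stages(width_coefficient, depth_coefficient):
    """Scale channels with the width coefficient and repeats with the depth coefficient."""
    return [(make_divisible(c * width_coefficient), math.ceil(r * depth_coefficient))
            for c, r in B0_STAGES]


b4 = scaled_stages(width_coefficient=1.4, depth_coefficient=1.8)
print(b4)
print("MBConv blocks, B0:", sum(r for _, r in B0_STAGES), "B4:", sum(r for _, r in b4))
```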

[Figure 8]

 

4. Training the network from the command line

Run the script with -h to list the arguments that can be set; here the number of epochs is set to 30.

[Figure 9]
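
Assume the training script from section 5.4 is saved as `train.py`. Setting 30 epochs then amounts to `python train.py --epochs 30`. The sketch below rebuilds part of its parser to show how argparse maps the flags: hyphens become underscores, so `--batch-size` is read as `args.batch_size`. It declares `--freeze-layers` with `action='store_true'`. With `type=bool`, any non-empty string, including `"False"`, would be converted to `True`.

```python
import argparse

# A trimmed copy of the arguments in train.py (section 5.4).
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epochs', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=16)        # read as args.batch_size
parser.add_argument('--data-path', type=str, default="./data/flower")
parser.add_argument('--freeze-layers', action='store_true')      # False unless the flag is given

# Equivalent to running: python train.py --epochs 30
args = parser.parse_args(['--epochs', '30'])
print(args.epochs, args.batch_size, args.data_path, args.freeze_layers)  # 30 16 ./data/flower False
```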

Training process:

[Figure 10]

 

Prediction (only a single image is predicted here):

[Figure 11]

[Figure 12]

 

5. Code

Code for the EfficientNet network:

 

5.1 model

import math
import copy
from functools import partial
from collections import OrderedDict
from typing import Optional, Callable

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


# Round a channel count to the nearest multiple of 8, which is hardware-friendly
def _make_divisible(ch, divisor=8, min_ch=None):

    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf

    This function is taken from the rwightman.
    It can be seen here:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py#L140
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


# Conv + BN + Swish activation
class ConvBNActivation(nn.Sequential):
    def __init__(self,
                 in_planes: int,    # input channels
                 out_planes: int,   # output channels
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        padding = (kernel_size - 1) // 2    # "same" padding; stride decides whether the size is halved
        if norm_layer is None:      # BN layer
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:    # activation function
            activation_layer = nn.SiLU      # Swish  (torch>=1.7)

        super(ConvBNActivation, self).__init__(nn.Conv2d(in_channels=in_planes,
                                                         out_channels=out_planes,
                                                         kernel_size=kernel_size,
                                                         stride=stride,
                                                         padding=padding,
                                                         groups=groups,
                                                         bias=False),
                                               norm_layer(out_planes),
                                               activation_layer())


# SE module
class SqueezeExcitation(nn.Module):
    def __init__(self,
                 input_c: int,   # input channels (channels entering the MBConv block)
                 expand_c: int,  # expanded channels (after the first 1x1 conv of MBConv)
                 squeeze_factor: int = 4):
        super(SqueezeExcitation, self).__init__()
        squeeze_c = input_c // squeeze_factor   # number of units in the first FC layer
        self.fc1 = nn.Conv2d(expand_c, squeeze_c, 1)    # 1x1 conv used in place of a fully connected layer
        self.ac1 = nn.SiLU()  # alias Swish
        self.fc2 = nn.Conv2d(squeeze_c, expand_c, 1)
        self.ac2 = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        scale = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        scale = self.fc1(scale)
        scale = self.ac1(scale)
        scale = self.fc2(scale)
        scale = self.ac2(scale)
        return scale * x


# Inverted residual structure: MBConv configuration
class InvertedResidualConfig:
    def __init__(self,
                 kernel: int,          # 3 or 5
                 input_c: int,         # input channels
                 out_c: int,           # output channels
                 expanded_ratio: int,  # 1 or 6
                 stride: int,          # 1 or 2
                 use_se: bool,         # True
                 drop_rate: float,     # drop rate (stochastic depth)
                 index: str,           # 1a, 2a, 2b, ...: name of the MBConv block
                 width_coefficient: float):     # width multiplier
        self.input_c = self.adjust_channels(input_c, width_coefficient)  # MBConv input channels; B1~B7 scale B0's value by the width coefficient
        self.kernel = kernel
        self.expanded_c = self.input_c * expanded_ratio     # channels after the first 1x1 (expansion) conv
        self.out_c = self.adjust_channels(out_c, width_coefficient)
        self.use_se = use_se    # SE module
        self.stride = stride
        self.drop_rate = drop_rate
        self.index = index

    @staticmethod   # round channels to a multiple of 8
    def adjust_channels(channels: int, width_coefficient: float):
        return _make_divisible(channels * width_coefficient, 8)


# MBConv block
class InvertedResidual(nn.Module):
    def __init__(self,
                 cnf: InvertedResidualConfig,   # inverted residual (MBConv) configuration
                 norm_layer: Callable[..., nn.Module]):
        super(InvertedResidual, self).__init__()

        if cnf.stride not in [1, 2]:    # stride must be 1 or 2
            raise ValueError("illegal stride value.")

        self.use_res_connect = (cnf.stride == 1 and cnf.input_c == cnf.out_c)   # shortcut

        layers = OrderedDict()      # ordered dict used to build the block
        activation_layer = nn.SiLU  # Swish

        # expand
        if cnf.expanded_c != cnf.input_c:   # expansion conv only when channels expand; skipped when the ratio is 1
            layers.update({"expand_conv": ConvBNActivation(cnf.input_c,
                                                           cnf.expanded_c,
                                                           kernel_size=1,
                                                           norm_layer=norm_layer,
                                                           activation_layer=activation_layer)})

        # depthwise
        layers.update({"dwconv": ConvBNActivation(cnf.expanded_c,
                                                  cnf.expanded_c,
                                                  kernel_size=cnf.kernel,
                                                  stride=cnf.stride,
                                                  groups=cnf.expanded_c,    # depthwise conv
                                                  norm_layer=norm_layer,
                                                  activation_layer=activation_layer)})

        if cnf.use_se:
            layers.update({"se": SqueezeExcitation(cnf.input_c,cnf.expanded_c)})

        # project: the 1x1 conv after SE
        layers.update({"project_conv": ConvBNActivation(cnf.expanded_c,
                                                        cnf.out_c,
                                                        kernel_size=1,
                                                        norm_layer=norm_layer,
                                                        activation_layer=nn.Identity)}) # Identity: no activation

        self.block = nn.Sequential(layers)  # MBConv
        self.out_channels = cnf.out_c
        self.is_strided = cnf.stride > 1

        # dropout (DropPath) is used only when the block has a shortcut connection
        if self.use_res_connect and cnf.drop_rate > 0:
            self.dropout = DropPath(cnf.drop_rate)
        else:
            self.dropout = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        result = self.block(x)
        result = self.dropout(result)
        if self.use_res_connect:
            result += x

        return result


# EfficientNet implementation
class EfficientNet(nn.Module):
    def __init__(self,
                 width_coefficient: float,      # width coefficient
                 depth_coefficient: float,      # depth coefficient
                 num_classes: int = 1000,       # number of classes
                 dropout_rate: float = 0.2,     # dropout before the final FC layer; varies from B0 to B7
                 drop_connect_rate: float = 0.2,    # drop-connect (stochastic depth) rate in MBConv; 0.2 for all variants
                 block: Optional[Callable[..., nn.Module]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None
                 ):
        super(EfficientNet, self).__init__()

        # structure parameters for stage 2 ~ stage 8
        # kernel_size, in_channel, out_channel, exp_ratio, strides, use_SE, drop_connect_rate, repeats
        default_cnf = [[3, 32, 16, 1, 1, True, drop_connect_rate, 1],
                       [3, 16, 24, 6, 2, True, drop_connect_rate, 2],
                       [5, 24, 40, 6, 2, True, drop_connect_rate, 2],
                       [3, 40, 80, 6, 2, True, drop_connect_rate, 3],
                       [5, 80, 112, 6, 1, True, drop_connect_rate, 3],
                       [5, 112, 192, 6, 2, True, drop_connect_rate, 4],
                       [3, 192, 320, 6, 1, True, drop_connect_rate, 1]]

        def round_repeats(repeats): # B1-B7 deepen the network; round the repeat count up to an integer
            return int(math.ceil(depth_coefficient * repeats))

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

        # apply the width multiplier to channel counts
        adjust_channels = partial(InvertedResidualConfig.adjust_channels,width_coefficient=width_coefficient)

        # build inverted_residual_setting
        bneck_conf = partial(InvertedResidualConfig,width_coefficient=width_coefficient)

        b = 0
        num_blocks = float(sum(round_repeats(i[-1]) for i in default_cnf))  # total number of MBConv blocks
        inverted_residual_setting = []
        for stage, args in enumerate(default_cnf):  # iterate over stages
            cnf = copy.copy(args)
            for i in range(round_repeats(cnf.pop(-1))):    # iterate over the MBConv blocks in this stage
                if i > 0:
                    # strides equal 1 except first cnf
                    cnf[-3] = 1  # strides
                    cnf[1] = cnf[2]  # input_channel equal output_channel

                cnf[-1] = args[-2] * b / num_blocks  # update dropout ratio
                index = str(stage + 1) + chr(i + 97)  # 1a, 2a, 2b, ...
                inverted_residual_setting.append(bneck_conf(*cnf, index))
                b += 1

        # create layers
        layers = OrderedDict()

        # first conv
        layers.update({"stem_conv": ConvBNActivation(in_planes=3,
                                                     out_planes=adjust_channels(32),
                                                     kernel_size=3,
                                                     stride=2,
                                                     norm_layer=norm_layer)})

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.update({cnf.index: block(cnf, norm_layer)})

        # build top
        last_conv_input_c = inverted_residual_setting[-1].out_c
        last_conv_output_c = adjust_channels(1280)
        layers.update({"top": ConvBNActivation(in_planes=last_conv_input_c,
                                               out_planes=last_conv_output_c,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})

        self.features = nn.Sequential(layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        classifier = []     # classifier head
        if dropout_rate > 0:
            classifier.append(nn.Dropout(p=dropout_rate, inplace=True))
        classifier.append(nn.Linear(last_conv_output_c, num_classes))
        self.classifier = nn.Sequential(*classifier)

        # initial weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def efficientnet_b0(num_classes=1000):
    # input image size 224x224
    return EfficientNet(width_coefficient=1.0,depth_coefficient=1.0,dropout_rate=0.2,num_classes=num_classes)


def efficientnet_b1(num_classes=1000):
    # input image size 240x240
    return EfficientNet(width_coefficient=1.0,depth_coefficient=1.1,dropout_rate=0.2,num_classes=num_classes)


def efficientnet_b2(num_classes=1000):
    # input image size 260x260
    return EfficientNet(width_coefficient=1.1,depth_coefficient=1.2,dropout_rate=0.3,num_classes=num_classes)


def efficientnet_b3(num_classes=1000):
    # input image size 300x300
    return EfficientNet(width_coefficient=1.2,depth_coefficient=1.4,dropout_rate=0.3,num_classes=num_classes)


def efficientnet_b4(num_classes=1000):
    # input image size 380x380
    return EfficientNet(width_coefficient=1.4,depth_coefficient=1.8,dropout_rate=0.4,num_classes=num_classes)


def efficientnet_b5(num_classes=1000):
    # input image size 456x456
    return EfficientNet(width_coefficient=1.6,depth_coefficient=2.2,dropout_rate=0.4,num_classes=num_classes)


def efficientnet_b6(num_classes=1000):
    # input image size 528x528
    return EfficientNet(width_coefficient=1.8,depth_coefficient=2.6,dropout_rate=0.5,num_classes=num_classes)


def efficientnet_b7(num_classes=1000):
    # input image size 600x600
    return EfficientNet(width_coefficient=2.0,depth_coefficient=3.1,dropout_rate=0.5,num_classes=num_classes)
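
The loop that builds `inverted_residual_setting` assigns block `b` a drop rate of `drop_connect_rate * b / num_blocks`. The rate therefore grows linearly from 0 for the first block to just below 0.2 for the last. It only takes effect in blocks that use the shortcut, since `DropPath` is replaced by `nn.Identity` elsewhere. A small self-contained recomputation for B0:

```python
import math

# Repeats per stage of EfficientNet-B0 (last column of default_cnf in section 5.1).
repeats = [1, 2, 2, 3, 3, 4, 1]
drop_connect_rate = 0.2
depth_coefficient = 1.0  # B0

num_blocks = sum(math.ceil(depth_coefficient * r) for r in repeats)
# Same rule as `cnf[-1] = args[-2] * b / num_blocks`: block b gets a linearly increasing rate.
rates = [drop_connect_rate * b / num_blocks for b in range(num_blocks)]
print(num_blocks, rates[0], rates[-1])  # 16 0.0 0.1875
```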

5.2 dataset

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """自定义数据集"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' is a color image, 'L' is grayscale
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for reference, see the official default_collate implementation:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
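
`MyDataSet` raises an error on any image that is not in RGB mode. Your own data may contain grayscale (`L`), palette (`P`) or RGBA PNG files. One alternative, sketched below with a hypothetical helper `load_rgb`, is to convert such images instead of raising. Whether silent conversion is acceptable depends on your data.

```python
import os
import tempfile

from PIL import Image


def load_rgb(path):
    """Hypothetical helper: open an image and convert it to 3-channel RGB
    instead of raising, so L, P or RGBA files are still usable."""
    with Image.open(path) as img:
        return img.convert("RGB")  # convert() returns a new, fully loaded image


# Usage: a grayscale PNG comes back as a 3-channel RGB image.
path = os.path.join(tempfile.mkdtemp(), "gray.png")
Image.new("L", (32, 24), color=128).save(path)
rgb = load_rgb(path)
print(rgb.mode, rgb.size)  # RGB (32, 24)
```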

5.3 utils

import os
import sys
import json
import random
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # iterate over the folders; each folder is one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort so the order is the same on every platform
    flower_class.sort()
    # map each class name to an integer index
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # paths of all training images
    train_images_label = []  # class indices of the training images
    val_images_path = []  # paths of all validation images
    val_images_label = []  # class indices of the validation images
    every_class_num = []  # total number of samples per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # sort so the order is the same on every platform
        images.sort()
        # index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images at the given ratio
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled for validation: add to the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # otherwise add to the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than 0."
    assert len(val_images_path) > 0, "number of validation images must greater than 0."

    plot_image = False
    if plot_image:
        # bar chart of the number of images per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0, 1, 2, 3, 4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add value labels above the bars
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        # x-axis label
        plt.xlabel('image class')
        # y-axis label
        plt.ylabel('number of images')
        # chart title
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


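def train_one_epoch(model, optimizer, data_loader, device, epoch, batch):
    # `batch` is unused (CrossEntropyLoss already averages over the batch); kept so train.py's call still works
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()         # loss function (mean over the batch)
    running_loss = 0.0                                  # sum of the batch losses in one epoch
    data_loader = tqdm(data_loader, file=sys.stdout)

    for step, (images, labels) in enumerate(data_loader):

        optimizer.zero_grad()               # clear gradients
        pred = model(images.to(device))     # forward pass
        loss = loss_function(pred, labels.to(device))   # compute the loss
        loss.backward()                     # backpropagation
        running_loss += loss.item()
        optimizer.step()                    # update the weights

        # running mean of the batch losses seen so far
        data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(running_loss / (step + 1), 3))

    return running_loss                     # divide by len(data_loader) to get the epoch mean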


@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()
    sum_num = 0    # number of correctly predicted samples
    data_loader = tqdm(data_loader, file=sys.stdout)

    for images, labels in data_loader:
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    return sum_num
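
A note on the loss that `train_one_epoch` reports. `torch.nn.CrossEntropyLoss()` uses `reduction='mean'` by default, so each `loss.item()` is already a per-sample average over its batch. Summing the batch losses and dividing by the number of batches gives the epoch mean. Dividing by the number of samples, or by the batch size again, would shrink it. A quick check with random logits:

```python
import torch

torch.manual_seed(0)
loss_mean = torch.nn.CrossEntropyLoss()                 # default reduction='mean'
loss_sum = torch.nn.CrossEntropyLoss(reduction='sum')

# Three fake batches of logits for a 5-class problem, batch size 4.
batches = [(torch.randn(4, 5), torch.randint(0, 5, (4,))) for _ in range(3)]

batch_means = [loss_mean(logits, labels).item() for logits, labels in batches]
total_sum = sum(loss_sum(logits, labels).item() for logits, labels in batches)
num_samples = sum(len(labels) for _, labels in batches)

# With equal batch sizes, the mean of the batch means equals the per-sample mean.
epoch_mean = sum(batch_means) / len(batch_means)
print(epoch_mean, total_sum / num_samples)
```

If the last batch is smaller, the mean of batch means is only an approximation of the per-sample mean.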

5.4 train

import os
import math
import argparse

import torch
import torch.optim as optim
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from torch.utils.data import DataLoader
from model import efficientnet_b0 as create_model   # import the EfficientNet network
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    if os.path.exists("./weights") is False:    # 创建保留权重的文件夹
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    # B0-B7 each expect a different input size
    img_size = {"B0": 224,"B1": 240,"B2": 260,"B3": 300,"B4": 380,"B5": 456,"B6": 528,"B7": 600}
    num_model = "B0"    # B0 is used here

    # preprocessing
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model]),
                                   transforms.CenterCrop(img_size[num_model]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the datasets
    train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])

    # number of samples
    num_trainSet = len(train_dataset)
    num_valSet = len(val_dataset)

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers

    # data loaders
    train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)
    val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)

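    # instantiate the network
    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":      # load pretrained weights
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            model_dict = model.state_dict()
            # keep only weights whose name exists in the model and whose size matches
            # (e.g. a 1000-class classifier is skipped when num_classes differs)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if k in model_dict and model_dict[k].numel() == v.numel()}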
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    if args.freeze_layers:     # freeze weights
        for name, para in model.named_parameters():
            # freeze everything except the last conv layer and the fully connected layer
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]         # optimize only the parameters that are not frozen
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)

    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    best_acc = 0.0
    for epoch in range(args.epochs):
        # train
        loss_all = train_one_epoch(model=model, optimizer=optimizer, data_loader=train_loader, device=device, epoch=epoch, batch=batch_size)

        scheduler.step()

        # validate
        acc_all = evaluate(model=model, data_loader=val_loader, device=device)

        print("[epoch :%d],train loss:%.4f ,test accuracy: %.4f" % (epoch, loss_all / num_trainSet, acc_all / num_valSet))

        if acc_all > best_acc:  # 保留最好的权重
            best_acc = acc_all
            torch.save(model.state_dict(),'./weights/shuffleNet_V2.pth')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--lrf', type=float, default=0.01)

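    parser.add_argument('--data-path', type=str,default="./data/flower")     # dataset root, one sub-folder per class; read_split_data splits it into train/val
    parser.add_argument('--weights', type=str, default='',help='initial weights path')  # leave as '' if there are no pretrained weights
    parser.add_argument('--freeze-layers', action='store_true')        # pass this flag to freeze the feature-extraction layers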
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    print('start training....')
    main(opt)
    print("finish training!!!")

5.5 predict

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # work around the "OMP: Error #15" duplicate OpenMP runtime crash seen in some conda setups

import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnet_b0 as create_model


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = {"B0": 224,"B1": 240,"B2": 260,"B3": 300,"B4": 380,"B5": 456,"B6": 528,"B7": 600}
    num_model = "B0"

    # preprocessing
    data_transform = transforms.Compose(
        [transforms.Resize(img_size[num_model]),
         transforms.CenterCrop(img_size[num_model]),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "./OIP-C.jpg"
    img = Image.open(img_path).convert('RGB')  # ensure 3 channels (e.g. drop a PNG alpha channel)
    plt.imshow(img)

    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = create_model(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/shuffleNet_V2.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))     # 打印预测结果
    plt.show()


if __name__ == '__main__':
    main()
