[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)

1. ResNet理论部分

网络的亮点

  1. 超深的网络结构(突破1000层)
  2. 提出Residual模块
  3. 使用BN(Batch Normalization)加速训练(不使用Dropout)

1.1 Residual结构(残差结构)

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第1张图片
左边是ResNet-34的Block构成,右边是ResNet-50/101/152以至于更深网络使用的Block结构。

其中, ⊕ \oplus 表示两个形状相同的tensor对应位置元素相加, 1 × 1 1 \times 1 1×1 卷积用来升维和降维,代码表示为:

# 升维使用的卷积
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channels

# 降维使用的卷积
self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion, kernel_size=1, stride=1, bias=False)  # unsqueeze channels

1.1.1 两种Residual结构可训练参数量对比

其实到这里我们应该有一个疑问,为什么构造浅层网络要使用A形态(左边的结构),构造深层网络使用B形态(右边的结构)?

主要原因是,相比B,A拥有更多的参数量,该模块的表示能力对应增强。但如果在构造深层结构时使用A结构就会引发一个问题——模型需要训练的参数量太大了,虽然模型的表示能力很强,但训练时间太久。为了解决这个问题,B结构被提出,目的就是为了解决参数量大的问题。

这里我们对比一下AB的可训练参数量(假设输入的通道数为 256 256 256):

P a r a m s A = 256 × 3 × 3 × 256 + 256 × 3 × 3 × 256 = 1 , 179 , 648 P a r a m s B = 256 × 1 × 1 × 64 + 64 × 3 × 3 × 64 + 64 × 1 × 1 × 256 = 69 , 632 \mathrm{Params_A} = 256 \times 3 \times 3 \times 256 + 256 \times 3 \times 3 \times 256 = 1,179,648 \\ \mathrm{Params_B} = 256 \times 1 \times 1 \times 64 + 64 \times 3 \times 3 \times 64 + 64 \times 1 \times 1 \times 256 = 69, 632 ParamsA=256×3×3×256+256×3×3×256=1,179,648ParamsB=256×1×1×64+64×3×3×64+64×1×1×256=69,632

我们应该明白了,在使用残差结构构造深层网络时,如果使用A结构,那么网络的参数量太大了,训练昂贵。

1.2 需进行下采样的残差结构

其实很容易想到,Residual中 ⊕ \oplus 需要两个矩阵的shape相同,但如果我们需要进行下采样该怎么办?

1.2.1 浅层残差结构——A

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第2张图片
很容易理解,不再赘述。

1.2.2 深层残差结构——B

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第3张图片
注意是3×3卷积进行的下采样。

1.2 ResNet网络结构

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第4张图片

2. Batch Normalization(BN)

Batch Normalization是google团队在2015年论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》提出的。通过该方法能够加速网络的收敛并提升准确率。

2.1 BN的目的

Batch Normalization 的目的是使我们的一批(1个Batch)的特征图满足均值为0,方差为1的高斯分布(正态分布)。

2.2 Batch Normalization的原理

我们在图像预处理过程中通常会对图像进行标准化处理,这样能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律)。而Batch Normalization的目的就是使我们的feature map满足均值为0,方差为1的分布规律。

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第5张图片


[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第6张图片

“对于一个拥有 d d d维的输入 x x x,我们将对它的每一个维度进行标准化处理。” 假设我们输入的 x x x 是RGB三通道的彩色图像,那么这里的 d d d 就是输入图像的 c h a n n e l s channels channels d = 3 d=3 d=3 x = ( x ( 1 ) , x ( 2 ) , x ( 3 ) ) x=(x^(1), x^{(2)}, x^{(3)}) x=(x(1),x(2),x(3)),其中 x ( 1 ) x^{(1)} x(1) 就代表我们的R通道所对应的特征矩阵,依此类推。标准化处理也就是分别对我们的R通道,G通道,B通道进行处理。上面的公式不用看,原文提供了更加详细的计算公式:

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第7张图片

  • 首先计算每个batch同一个通道所有的对应的均值 μ B \mu_{\mathcal{B}} μB 和方差 σ B 2 \sigma^2_{\mathcal{B}} σB2
  • 然后对原参数进行标准化,即可得到经标准化处理后的数值 x i ^ \hat{x_i} xi^,其中 ϵ \epsilon ϵ 为极小数(防止分母为0)。
  • 最后通过 γ \gamma γ β \beta β 对特征图的数值进一步调整,其中 γ \gamma γ β \beta β 分别用于调整方差均值的大小。如果不进行 γ \gamma γ β \beta β 调整,那么整批(Batch)的数据符合均值为0,方差为1的高斯分布规律。

均值为0, 方差为1的高斯分布不好吗,为什么还要进行调整?
对于不同的数据集来说,高斯分布不一定是最好的,所以BN有两个可以学习的参数 γ , β \gamma, \beta γ,β,通过反向传播进行学习和更新
Note:

  • γ \gamma γ 是用来调整数值分布的方差大小 β \beta β是用来调节数值均值的位置(均值的中心位置)。这两个参数是在反向传播过程中学习并更新的,而不像均值和方差那样正向传播中更新的
  • 均值 μ B \mu_{\mathcal{B}} μB 和方差 σ B 2 \sigma^2_{\mathcal{B}} σB2 的默认值分别为 0 0 0 1 1 1

我们刚刚有说让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后在进行标准化处理,对于一个大型的数据集明显是不可能的(和SGD的动机类似),所以论文中说的是Batch Normalization,也就是我们计算一个Batch数据的feature map然后在进行标准化(batch越大越接近整个数据集的分布,效果越好)。

我们根据上图的公式可以知道代表着我们计算的feature map每个维度(channel)的均值,注意 μ B \mu_{\mathcal{B}} μB 是一个向量不是一个值(数量维度就是输入特征图的Channel维度大小),向量 μ B \mu_{\mathcal{B}} μB 的每一个元素代表着一个维度(channel)的均值。 σ B 2 \sigma^2_{\mathcal{B}} σB2 代表着我们计算的feature map每个维度(channel)的方差,注意 σ B 2 \sigma^2_{\mathcal{B}} σB2 是一个向量不是一个值,向量的每一个元素代表着一个维度(channel)的方差,然后根据 μ B \mu_{\mathcal{B}} μB σ B 2 \sigma^2_{\mathcal{B}} σB2 计算标准化处理后得到的值。下图给出了一个计算均值 μ B \mu_{\mathcal{B}} μB 和方差 σ B 2 \sigma^2_{\mathcal{B}} σB2 的示例:

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第8张图片

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么 x ( 1 ) x^{(1)} x(1) 代表该batch的所有feature的 c h a n n e l 1 \mathrm{channel_1} channel1 的数据,同理 x ( 2 ) x^{(2)} x(2)代表该batch的所有feature的 c h a n n e l 2 \mathrm{channel_2} channel2的数据。然后分别计算 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 的均值与方差,得到 μ B \mu_{\mathcal{B}} μB σ B 2 \sigma^2_{\mathcal{B}} σB2 两个向量。然后再根据标准差计算公式分别计算每个channel的值。

在我们训练网络的过程中,我们是通过一个batch一个batch的数据进行训练的,但是我们在预测过程中通常都是输入一张图片进行预测,此时batch size为1,如果再通过上述方法计算均值和方差就没有意义了。所以我们在训练过程中要去不断的计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差

最后在我们验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理

Note: 均值 μ B \mu_{\mathcal{B}} μB 和方差 σ B 2 \sigma^2_{\mathcal{B}} σB2 并不是一个值,而是一个向量,第一个维度(表示数量)就是输入特征图的Channel维度大小。这也解释了为什么nn.BatchNorm2d/3d(维度)需要维度参数了

2.3 使用PyTorch进行BN的复现

在训练过程中,均值 μ B \mu_{\mathcal{B}} μB 和方差 σ B 2 \sigma^2_{\mathcal{B}} σB2 是通过计算当前Batch数据得到的记为为 μ n o w \mu _{now} μnow σ n o w 2 \sigma _{now}^{2} σnow2,而验证以及预测过程中所使用的均值方差是一个统计量记为 μ s t a t i s t i c \mu _{\mathrm{statistic}} μstatistic σ s t a t i s t i c 2 \sigma _{\mathrm{statistic}}^{2} σstatistic2。二者的具体更新策略如下,其中 m o m e n t u m \mathrm{momentum} momentum默认为0.1:

μ s t a t i s t i c + 1 = ( 1 − m o m e n t u m ) ∗ μ s t a t i s t i c + m o m e n t u m ∗ μ n o w σ s t a t i s t i c + 1 2 = ( 1 − m o m e n t u m ) ∗ σ s t a t i s t i c 2 + m o m e n t u m ∗ σ n o w \mathrm{ \mu_{statistic + 1} = (1 - momentum) * \mu_{statistic} + momentum * \mu_{now} } \\ \mathrm{ \sigma_{statistic + 1}^2 = (1 - momentum) * \sigma_{statistic}^2 + momentum * \sigma_{now} } μstatistic+1=(1momentum)μstatistic+momentumμnowσstatistic+12=(1momentum)σstatistic2+momentumσnow

这里要注意一下,在pytorch中对当前批次feature进行BN处理时所使用的 σ n o w 2 \sigma _{now}^{2} σnow2总体标准差,计算公式如下:

σ n o w 2 = 1 m ∑ i = 1 m ( x i − μ n o w ) 2 \mathrm{ \sigma_{now}^2 = \frac{1}{m} \sum^m_{i=1} (x_i - \mu_{now})^2 } σnow2=m1i=1m(xiμnow)2

在更新统计量 σ s t a t i s t i c 2 \sigma _{statistic}^{2} σstatistic2 时采用的 σ n o w 2 \sigma _{now}^{2} σnow2样本标准差,计算公式如下:

σ n o w 2 = 1 m − 1 ∑ i = 1 m ( x i − μ n o w ) 2 \mathrm{ \sigma_{now}^2 = \frac{1}{m-1} \sum^m_{i=1} (x_i - \mu_{now})^2 } σnow2=m11i=1m(xiμnow)2

下面是使用PyTorch做的测试,代码如下:

  1. bn_process函数是自定义的BN处理方法,用来验证是否和使用官方BN处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。
  2. 初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化 γ = 1 , β = 0 \gamma=1,\beta=0 γ=1β=0
import numpy as np
import torch.nn as nn
import torch


def bn_process(feature, mean, var):
    feature_shape = feature.shape  # [BS, C, H, W] = [2, 2, 2, 2]
    for i in range(feature_shape[1]):  # 遍历Channel维度
        feature_t = feature[:, i, :, :]  # channel-wise取出数据

        """
            std()是计算标准差的函数,使用时要额外注意ddof这个参数:
                在ddof = 0时,计算的是总体标准偏差,标准差公式根号内除以 n。
                在ddof = 1时,计算的是样本标准差,标准差公式根号内除以 (n-1)。
        """
        mean_t = feature_t.mean()  # 求均值\mu
        std_t1 = feature_t.std()  # 总体标准差 \sigma
        std_t2 = feature_t.std(ddof=1)  # 样本标准差 \sigma

        # 对数据进行标准化处理
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)

        # 使用均值和样本标准差更新统计均值和标准差。
        mean[i] = mean[i] * 0.9 + mean_t * 0.1
        var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
        # 这里并没有更新\gamma和\beta的代码,二者是通过反向传播学习、更新的,并不是通过正向传播!

    print(feature)  # 打印BN后的特征图


if __name__ == '__main__':
    # 随机生成一个batch为2,channel为2,height=width=2的特征向量
    # [batch, channel, height, width]
    feature1 = torch.randn(2, 2, 2, 2)
    # 初始化统计均值和方差
    calculate_mean = [0.0, 0.0]  # \gamma
    calculate_var = [1.0, 1.0]  # \beta
    # print(feature1.numpy())

    # 注意要使用copy()深拷贝 -> 防止原本的特征图被破坏
    bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

    bn = nn.BatchNorm2d(2, eps=1e-5)  # 第一个参数是输入维度,第二参数是\epsilon(防止分母为0)
    output = bn(feature1)
    print(output)

2.3 使用BN时的注意事项

  1. 训练时要将training参数设置为True,在验证时将training参数设置为False。
    • 训练:model.train()
    • 验证/测试:model.eval()

    这是因为在训练时BN需通过正向传播不断统计均值和方差并更新这两个参数;同时也会通过反向传播对 均值调整值 γ \gamma γ 和方差调整值 β \beta β 进行学习和更新
    而在验证/测试时,并不需要统计和更新均值与方差,而是使用之前在训练时统计好的均值和方差以及二者的调整值进行BN,这样就可以实现训练和验证/测试数据都是同一分布

  2. batch size尽可能设置大点,设置小后表现可能很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差

    当BS=1时,BN是没有什么作用的,甚至效果可能会变差

  3. 建议将BN层放在卷积层(Conv)和激活层(例如ReLU)之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,即便使用了偏置bias求出的结果也是一样的: y i b = y i y_i^b = y_i yib=yi

    BN层放在卷积层(Conv)和激活层(例如ReLU)之间形成经典的三明治结构:Conv(without bias) -> BN -> Non-linear

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第9张图片

2.3 BN总结

  1. 均值 μ \mu μ 和方差 σ 2 \sigma^2 σ2正向传播过程中统计得到
  2. 均值调整值 β \beta β 和方差调整值 β \beta β反向传播过程中训练得到

3. 迁移学习

3.1 迁移学习的优势

  1. 能够快速训练出一个理想的结果

    如果我们从头开始训练一个模型,可能需要几十个epoch才能得到一个不错的结果,但如果使用迁移学习,可能只需要迭代2,3个epoch就可以得到理想的结果。
    迁移学习可以大大减少训练时间

  2. 当数据集较小时,也能训练出理想的结果

    如果网络特别大(网络可训练参数很多),如果数据集比较小,那么这个小的数据集是不足以训练整个网络的(很容易发生过拟合现象),最终的训练结果会非常糟糕;
    如果使用迁移学习,使用别人训练好的参数再去训练比较小的数据集,一般可以得到一个不错的结果。

Note: 使用他人的预训练模型参数时,要和其预处理方式一致,否则结果会很糟糕

3.2 迁移学习大体思想

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第10张图片
对于浅层的卷积层来说,学习到的角点信息、纹理信息一般是比较通用的,所以我们只需要训练后面几层或者分类头,快速学习新的、高维的数据特征,从而实现一个理想的效果。

3.3 迁移学习的方式

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层

[学习笔记] ResNet,BN,以及迁移学习(附带TensorBoard可视化)_第11张图片

VGG-16网络结构

3.3.1 载入权重后训练所有参数

VGG-16是在imagenet上进行训练,分类结果为1000。在使用这种方式进行迁移学习时,需对最后的全连接层分类个数进行调整以满足自用数据集分类数。

因为修改了最后的FC层,所以最后的FC层参数无法载入

3.3.2 载入权重后只训练最后几层参数

一般是固定全连接层之前的所有模型参数(不进行反向传播和梯度更新),只训练几个FC层。

这样做的好处:

  1. 训练参数变少
  2. 训练速度加快

同样也需要修改最后的分类数,最后的FC层参数无法加载

3.3.3 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层

这样做最大的好处是:可以载入所有的模型参数

3.4 迁移学习使用场景

3.4.1 算力/训练时间受限

建议使用第2、3种方法。

3.4.2 算力/训练时间不受限(想要得到最优的结果)

建议采用第一种方法(效果比2、3种方法高),且比不使用迁移学习的方法要快。

4. 代码

4.1 ResNet模型代码

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    """
        For ResNet-18/34
    """
    expansion = 1  # Channel will be change in this block

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        # Conv (without bias) -> BN -> ReLU
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """
    注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。
    但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2,
    这么做的好处是能够在top1上提升大概0.5%的准确率。
    可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.groups = groups
        self.width_per_group = width_per_group

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:  # 如果要进行下采样
            # 构造下采样层(虚线的identity)
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        # 构建第一个Block(只有第一个Block会进行下采样)
        layers.append(block(self.in_channel,
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        # 根据Block个数构建其他Block
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet18(num_classes=1000, include_top=True):
    # "https://download.pytorch.org/models/resnet18-f37072fd.pth"
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)


def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


def resnet152(num_classes=1000, include_top=True):
    # "https://download.pytorch.org/models/resnet152-394f9c45.pth"
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, include_top=include_top)


def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

4.2 训练代码(附带TensorBoard可视化)

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import resnet34
from torchvision.models import resnet
from torch.utils.tensorboard import SummaryWriter


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    """
    transforms.Resize()
        size (sequence or int): Desired output size. 
        
        If size is a sequence like (h, w), output size will be matched to this. 
        If size is an int, smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to (size * height / width, size)
    """
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),  # ① 先将最小边缩放到256(不是将图片缩放到256×256)
                                   transforms.CenterCrop(224),  # ② 缩放图片后再进行中心裁剪
                                   transforms.ToTensor(),  # ③ 将图片转换为tensor
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化(满足某一分布)
                                   ])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./pretrained/resnet34-b627a593.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    """
        net.fc -> 网络的全连接层
        net.fc.in_feature
            torch.nn.modules.linear.Linear 
            def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device: Any = None,
                 dtype: Any = None) -> None
        通过查看 Linear 的定义,我们发现它的输入参数为 in_features,所以我们可以调取它
    """
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)  # 重新定义网络的全连接层

    """
        以上是加载参数的官方提供的方法,即:
            1. 获取网络全连接层的输入
            2. 重新定义网络全连接层的输入和输出
            
        当然,除了这种方法,还有一种方法来实现,就是在加载参数字典的时候将字典中的全连接层参数删掉,这样就不会出现冲突了
    """

    net.to(device)

    # 在tb中添加 tensor 流动图
    dummy_input = torch.rand(6, 3, 224, 224).cuda()  # dummy: 一种对真实或原始物体的模仿,旨在用作实际的替代品
    tb.add_graph(net, dummy_input)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=lr)

    best_acc = 0.0

    train_steps = len(train_loader)

    """
        在网络训练和验证/测试时,.train()和.eval()一定要写,因为在网络的不同状态下会有不同的行为
    """

    for epoch in range(epochs):
        # train
        net.train()
        train_loss = 0.0  # 一个epoch中的训练损失
        train_correct_num = 0  # 一个epoch中的训练预测的正确个数
        train_bar = tqdm(train_loader, file=sys.stdout)

        # step: iteration num -> batch
        # data: data:
        #           1. img.toTensor;
        #           2. label
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()

            # inference
            logits = net(images.to(device))  # return a batch_size result
            # print(f"logits.shape: {logits.shape}")  # torch.Size([16, 5])

            # 训练阶段正确预测个数
            train_correct_num += torch.eq(torch.max(logits, dim=1)[1], labels.to(device)).sum().item()

            # 通过损失函数计算损失
            loss = loss_function(logits, labels.to(device))
            loss.backward()  # 对损失进行反向传播
            optimizer.step()  # 参数更新

            # print statistics by tqdm
            train_loss += loss.item()  # 累加batch的损失
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)

        # validate
        net.eval()  # 声明模型状态
        val_correct_num = 0.0  # 一个epoch中的验证预测的正确个数
        val_loss = 0.0  # # 一个epoch中的验证集损失
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                val_loss += loss_function(outputs, val_labels.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                val_correct_num += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        # 计算训练、验证准确率
        train_accurate = train_correct_num / train_num
        val_accurate = val_correct_num / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' % (epoch + 1, train_loss / train_steps, val_accurate))

        # 使用tensorboard可视化训练过程
        tb.add_scalar("[train] Loss", train_loss, epoch + 1)  # +1 令epoch从1开始
        tb.add_scalar("[train] top-1 acc", train_accurate, epoch + 1)

        # 使用tensorboard可视化验证过程
        tb.add_scalar("[val] Loss", val_loss, epoch + 1)  # +1 令epoch从1开始
        tb.add_scalar("[val] top-1 acc", val_accurate, epoch + 1)

        tb.add_scalars("[Accuracy] val-train", {"val": val_accurate, "train": train_accurate}, epoch + 1)

        # 统计需要查看的参数直方图
        # tb.add_histogram("conv1.bias", net.conv1.bias, epoch + 1)
        # tb.add_histogram("conv1.weight", net.conv1.weight, epoch + 1)
        # tb.add_histogram("conv2.bias", net.conv2.bias, epoch + 1)
        # tb.add_histogram("conv2.weight", net.conv2.weight, epoch + 1)

        # 保存模型
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), os.path.join(result_path, model_save_name))
            print(f"model has been save in {os.path.join(result_path, model_save_name)}")

    print('Finished Training')


if __name__ == '__main__':
    """Hyper-param"""
    epochs = 30
    batch_size = 16
    lr = 0.0001
    model_save_name = 'resNet34.pth'

    result_path = f"{os.getcwd()}/res"
    if not os.path.exists(result_path):
        os.mkdir(result_path)
    tb = SummaryWriter(log_dir=result_path, flush_secs=3)
    print(f"tb_path: {result_path}")

    main()

4.3 预测代码

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         # 这里一定要和训练时使用的方法一致,不然模型并不能正确预测
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
         ])

    # load image
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)  # 使用pillow读取图片
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)  # 模型前向传播需要BS维度,这里是为了添加该维度

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    """
        {
            "0": "daisy",
            "1": "dandelion",
            "2": "roses",
            "3": "sunflowers",
            "4": "tulips"
        }
    """

    with open(json_path, "r") as f:
        class_indict = json.load(f)
    # print(class_indict)  # {'0': 'daisy', '1': 'dandelion', '2': 'roses', '3': 'sunflowers', '4': 'tulips'}


    # create model
    model = resnet34(num_classes=nc).to(device)

    # load model weights
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()  # output -> list
        predict = torch.softmax(output, dim=0)  # 使用softmax获得这个列表元素的分数
        predict_cla = torch.argmax(predict).numpy()  # 求得上面list值最大的元素的index

    # json文件我们可以看成是一个dict,使用key取value
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)

    # 打印每一个类别的概率
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    img_path = "exp_rose.jpg"  # 预测图片的路径
    weights_path = "./res/resNet34.pth"
    nc = 5
    main()

参考

  1. https://www.bilibili.com/video/BV1T7411T7wa?spm_id_from=333.999.0.0
  2. https://blog.csdn.net/qq_37541097/article/details/104434557
  3. https://www.bilibili.com/video/BV14E411H7Uw?spm_id_from=333.999.0.0

你可能感兴趣的:(PyTorch,机器学习,分类网络,迁移学习,深度学习,pytorch)