ResNet || 基于PyTorch的代码实现 + 迁移学习

文章目录

  • 1 class Basicblock
    • 1.0 expansions是什么
    • 1.1 输入的参数
    • 1.2 conv1层的参数
    • 1.3 forward函数
  • 2 class Bottleneck

写在前面:上一篇博文介绍了ResNet 、 Batch Normalization 、 迁移学习的原理,有兴趣的小伙伴可以一起学习!

这里是引用

1 class Basicblock

  • 这里对应的是18-layer、34-layer 的结构
    ResNet || 基于PyTorch的代码实现 + 迁移学习_第1张图片
class BasicBlock(nn.Module):
    expansion = 1  # 标识主分支上的kernel个数是否发生变化

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

1.0 expansions是什么

  • 如下图所示,在 18-layer 中,卷积核的个数是不变的,而在 50-layer 中, 卷积核的个数是变化的。
  • 不变代表的是 expansion = 1
    ResNet || 基于PyTorch的代码实现 + 迁移学习_第2张图片

1.1 输入的参数

  1. in_channel : 输入特征矩阵的 深度
  2. out_channel : 输出特征矩阵的 深度,对应的就是卷积核的个数
  3. stride : 1×1 的卷积核
  4. downsample : 下采样为None,就是定义的虚线部分,这个BasicBlock 要同时能够表示 实线的resnet 和虚线的ResNet
    ResNet || 基于PyTorch的代码实现 + 迁移学习_第3张图片

1.2 conv1层的参数

  1. in_channels : 输入特征矩阵的深度
  2. out_channels : 输出特征矩阵的深度
  3. kernel_size : 卷积核的尺寸
  4. stride : 步长
  5. padding : 填零
  6. bias : 一般是不需要偏置的,如果是使用了BN就不能使用bias了
  7. 其余的就类似

1.3 forward函数

  1. identity : 就是捷径上的输出值,首先赋值为x
  2. 然后就是先经过一个conv1卷积,然后BN,然后relu,然后conv2卷积,然后BN,然后相加,然后relu输出

2 class Bottleneck

你可能感兴趣的:(机器学习算法,pytorch,迁移学习,深度学习)