[PyTorch] Reproducing the DenseNet Architecture and Loading Pretrained Weights

Preface

I have recently been building the DenseNet architecture myself. The original paper does not describe the concrete network parameters in much detail, so I had to follow the GitHub code released by the authors.
论文地址:https://arxiv.org/abs/1608.06993
论文github链接:https://github.com/liuzhuang13/DenseNet

Overview

The DenseNet architecture in the paper can be roughly divided into four parts: an initial feature-extraction stem, the DenseBlocks, the Transitions, and the final classification head.
[Figure 1: the DenseNet architecture, as illustrated in the paper]

The head and tail are standard CNN components. A DenseBlock is a small stack of densely connected convolutional layers; the depth of the network can be changed by varying how many layers each DenseBlock contains. A Transition compresses the feature maps (reduces the number of channels) and sits between two DenseBlocks.

DenseLayer

A DenseLayer is one layer inside a DenseBlock, and it may include the bottleneck layers mentioned by the authors.
A bottleneck layer means adding a 1×1 convolution before each 3×3 convolution to reduce the number of input feature maps and improve computational efficiency; without the bottleneck, a DenseLayer contains only the 3×3 convolution.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, BatchNorm):
        super(_DenseLayer, self).__init__()
        # Bottleneck: 1x1 conv reduces the input to bn_size * growth_rate channels
        self.add_module('norm1', BatchNorm(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                                           growth_rate, kernel_size=1, stride=1, bias=False))
        # 3x3 conv produces growth_rate new feature maps
        self.add_module('norm2', BatchNorm(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        # Dense connectivity: concatenate the input with the newly produced features
        return torch.cat([x, new_features], 1)

The growth_rate in the code refers to the number of new feature maps each DenseLayer adds (the k in the paper).
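As a quick shape check, here is a minimal sketch, assuming the class and imports above are in scope:

# Each DenseLayer concatenates its input with growth_rate new channels:
layer = _DenseLayer(num_input_features=64, growth_rate=32,
                    bn_size=4, drop_rate=0.0, BatchNorm=nn.BatchNorm2d)
x = torch.rand(1, 64, 56, 56)
print(layer(x).size())  # torch.Size([1, 96, 56, 56]) -> 64 + 32 channels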

Dense Block

A Dense Block simply sets the input of every DenseLayer to the concatenation of the outputs of all the layers that precede it.

class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, BatchNorm):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            # Each layer sees the block input plus everything produced so far
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate, BatchNorm)
            self.add_module('denselayer%d' % (i + 1), layer)
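A quick shape check of the block, under the same assumptions as before:

# A 6-layer block with growth_rate=32 turns 64 input channels into
# 64 + 6 * 32 = 256 output channels; the spatial size is unchanged.
block = _DenseBlock(num_layers=6, num_input_features=64, bn_size=4,
                    growth_rate=32, drop_rate=0.0, BatchNorm=nn.BatchNorm2d)
x = torch.rand(1, 64, 56, 56)
print(block(x).size())  # torch.Size([1, 256, 56, 56])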

Transition

A Transition compresses the feature maps (reduces the number of channels). Following the structure proposed in the paper, it consists of a 1×1 convolution followed by 2×2 average pooling, and is placed between two DenseBlocks. Different compression ratios can be achieved through the num_input_features and num_output_features settings.

class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, BatchNorm):
        super(_Transition, self).__init__()
        self.add_module('norm', BatchNorm(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
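Again a minimal shape check:

# A Transition halves both the channel count (here 256 -> 128, a 0.5
# compression ratio) and the spatial resolution (2x2 average pooling).
trans = _Transition(num_input_features=256, num_output_features=128,
                    BatchNorm=nn.BatchNorm2d)
x = torch.rand(1, 256, 56, 56)
print(trans(x).size())  # torch.Size([1, 128, 28, 28])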

The Complete DenseNet

Parameter Settings

The parameters follow the pretrained models; there are four of them: Densenet121, Densenet169, Densenet201 and Densenet161.
Densenet121, Densenet169 and Densenet201 share the same growth_rate (32) and num_init_features (64), while Densenet161 uses growth_rate=48 and num_init_features=96; bn_size, drop_rate and transition_num are identical across all four, whereas num_layers differs from model to model.
All four models contain 4 DenseBlocks and 3 Transitions, so num_layers holds four values, one for the number of DenseLayers in each DenseBlock, and transition_num is set to 3, the number of Transitions.
The concrete num_layers settings and the meaning of the other parameters are described in the comments of the code below.
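As a sanity check on how the channel count evolves, here is a small sketch of the arithmetic the constructor below performs, for Densenet121:

# Channel arithmetic for Densenet121: num_layers=(6, 12, 24, 16),
# growth_rate=32, num_init_features=64, with 3 halving Transitions.
num_features = 64
for i, num in enumerate((6, 12, 24, 16)):
    num_features += num * 32      # each DenseBlock adds num * growth_rate channels
    if i < 3:                     # each Transition halves the channel count
        num_features //= 2
    print(num_features)           # prints 128, 256, 512, 1024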

Code

class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" `_
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        num_layers (tuple of 4 ints) - how many layers in each pooling block  ---121-(6,12,24,16)  169-(6,12,32,32)  201-(6,12,48,32)  161-(6,12,36,24)
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        trainstion_num (int) - number of transition module  
    """

    def __init__(self,
                 BatchNorm,
                 growth_rate=32,
                 num_init_features=64,
                 bn_size=4,
                 drop_rate=0.2,
                 num_layers=(6, 12, 24, 16),
                 transition_num=3,):

        super(DenseNet, self).__init__()

        # Low_feature 1/4 size
        self.low_feature = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', BatchNorm(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        ]))
        # Middle_feature 1/16 size
        self.middle_feature = nn.Sequential()
        self.end_feature = nn.Sequential()
        num_features = num_init_features
        for i, num in enumerate(num_layers):
            if i < 2:
                # The first two DenseBlocks (and their Transitions) form middle_feature
                block = _DenseBlock(num_layers=num, num_input_features=num_features,
                             bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, BatchNorm=BatchNorm)
                num_features = num_features + num * growth_rate
                self.middle_feature.add_module('denseblock{}'.format(str(i+1)), block)
                if i < transition_num:
                    trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, BatchNorm=BatchNorm)
                    num_features = num_features // 2
                    self.middle_feature.add_module('transition{}'.format(str(i+1)), trans)
            else:
                # The last two DenseBlocks (and the third Transition) form end_feature
                block = _DenseBlock(num_layers=num, num_input_features=num_features,
                                    bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, BatchNorm=BatchNorm)
                num_features = num_features + num * growth_rate
                self.end_feature.add_module('denseblock{}'.format(str(i+1)), block)
                if i < transition_num:
                    trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, BatchNorm=BatchNorm)
                    num_features = num_features // 2
                    self.end_feature.add_module('transition{}'.format(str(i+1)), trans)
        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        low_feature = self.low_feature(x)
        middle_feature = self.middle_feature(low_feature)
        end_feature = self.end_feature(middle_feature)
        out = F.relu(end_feature, inplace=True)
        # Note: the pooling kernel size is tied to the input resolution used in
        # the test below; adjust it (or use adaptive pooling) for other sizes.
        out = F.avg_pool2d(out, kernel_size=8, stride=1).view(end_feature.size(0), -1)
        return low_feature, middle_feature, end_feature, out

Because I need to extract certain intermediate features, I defined low_feature, middle_feature and end_feature separately; if only the final output is needed, a single self.features can hold the whole structure, as sketched below.
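For reference, a minimal sketch of that single-self.features variant might look like this (DenseNetSimple is a hypothetical name; it reuses the _DenseBlock and _Transition classes above):

class DenseNetSimple(nn.Module):
    """Hypothetical variant: one self.features holding every stage."""
    def __init__(self, BatchNorm=nn.BatchNorm2d, growth_rate=32,
                 num_init_features=64, bn_size=4, drop_rate=0.2,
                 num_layers=(6, 12, 24, 16)):
        super(DenseNetSimple, self).__init__()
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7,
                                stride=2, padding=3, bias=False)),
            ('norm0', BatchNorm(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))
        num_features = num_init_features
        for i, num in enumerate(num_layers):
            # Append every DenseBlock (and Transition) to the same Sequential
            self.features.add_module('denseblock%d' % (i + 1),
                _DenseBlock(num_layers=num, num_input_features=num_features,
                            bn_size=bn_size, growth_rate=growth_rate,
                            drop_rate=drop_rate, BatchNorm=BatchNorm))
            num_features += num * growth_rate
            if i != len(num_layers) - 1:
                self.features.add_module('transition%d' % (i + 1),
                    _Transition(num_features, num_features // 2, BatchNorm))
                num_features //= 2

    def forward(self, x):
        return self.features(x)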

Loading Pretrained Parameters

With the network defined above we could train it directly, but the results may not be great, so the usual practice is to load the pretrained parameters first and then train.

Parameter Loading Code

# First we need the URLs for downloading the pretrained weights
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
# Then load the parameters
def densenet161(BatchNorm, pretrained=True):
    # Note: densenet161 uses a larger growth rate (48) and stem width (96)
    # than the other three models.
    model = DenseNet(BatchNorm,
                     growth_rate=48,
                     num_init_features=96,
                     bn_size=4,
                     drop_rate=0.2,
                     num_layers=(6, 12, 36, 24),
                     transition_num=3)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['densenet161'])
        model.load_state_dict(pretrained)
    return model

Testing the Loading Code

model = densenet169(BatchNorm=nn.BatchNorm2d)  # densenet169 is defined below
input = torch.rand(1, 3, 512, 512)
low_feature, middle_feature, end_feature, out = model(input)
print(low_feature.size())
print(middle_feature.size())
print(end_feature.size())
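For a 512×512 input to densenet169, the printed sizes should come out to torch.Size([1, 64, 128, 128]) for low_feature (1/4 resolution), torch.Size([1, 256, 32, 32]) for middle_feature (1/16) and torch.Size([1, 1664, 16, 16]) for end_feature (1/32), if my channel arithmetic above is right.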

In short, the code above loads the pretrained parameters into our own model, but the key names in the pretrained state_dict may not match the module names of the network we built, producing an error like the one below.
[Screenshot: state_dict key-mismatch error]
This is really just a string-processing problem: rename the keys of the loaded state_dict so they match our own network.
For the network structure I built above, the following code does the job.

def densenet169(BatchNorm, pretrained=True):
    model = DenseNet(BatchNorm,
                     growth_rate=32,
                     num_init_features=64,
                     bn_size=4,
                     drop_rate=0.2,
                     num_layers=(6, 12, 32, 32),
                     transition_num=3)
    if pretrained:
        pretrained = model_zoo.load_url(model_urls['densenet169'])
        # Our network has no final norm5 or classifier, so drop those entries
        del pretrained['classifier.weight']
        del pretrained['classifier.bias']
        del pretrained['features.norm5.weight']
        del pretrained['features.norm5.bias']
        del pretrained['features.norm5.running_mean']
        del pretrained['features.norm5.running_var']
        new_state_dict = OrderedDict()
        new_state_dict2 = OrderedDict()
        blockstr = 'denseblock'
        transstr = 'transition'
        # First pass: rename the top-level 'features' prefix to 'low_feature'
        for k, v in pretrained.items():
            name = k.replace('features', 'low_feature')
            new_state_dict[name] = v
        # Second pass: move denseblock/transition keys to middle_feature or
        # end_feature, and strip the dots from old-style names like 'norm.1'
        for k, v in new_state_dict.items():
            name = k
            for i in range(4):
                if i < 2:
                    if blockstr + str(i + 1) in name:
                        name = name.replace('low_feature', 'middle_feature')
                        name = name.replace('conv.', 'conv')
                        name = name.replace('norm.', 'norm')
                    elif transstr + str(i + 1) in name:
                        name = name.replace('low_feature', 'middle_feature')
                else:
                    if blockstr + str(i + 1) in name:
                        name = name.replace('low_feature', 'end_feature')
                        name = name.replace('conv.', 'conv')
                        name = name.replace('norm.', 'norm')
                    elif transstr + '3' in name:
                        name = name.replace('low_feature', 'end_feature')
            new_state_dict2[name] = v
        model.load_state_dict(new_state_dict2)
    return model
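Incidentally, when debugging this kind of remapping it helps to diff the key sets first; a small check like the following, placed just before the load_state_dict call above, uses only standard PyTorch calls:

# Optional sanity check, placed right before model.load_state_dict(...):
model_keys = set(model.state_dict().keys())
ckpt_keys = set(new_state_dict2.keys())
print('missing from checkpoint:', sorted(model_keys - ckpt_keys))
print('unexpected in checkpoint:', sorted(ckpt_keys - model_keys))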

After these changes, rerunning the test code above gives the result below, which means the pretrained parameters have been loaded successfully.
[Screenshot: the printed feature-map sizes after loading the pretrained weights]

Closing Remarks

The code draws on the relevant files in the authors' GitHub repo and on related material online; the key-remapping logic is my own and may not be perfect. Consider this a record of my own learning.
The complete code is available at https://github.com/Hayz2087/densenet

References

https://github.com/baldassarreFe/pytorch-densenet-tiramisu
https://www.cnblogs.com/ywheunji/p/10605614.html
