经典神经网络 -- DenseNet : 设计原理与pytorch实现

原理

概念与网络结构

       DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection)

       DenseNet的一大特色是通过 特征在channel上的连接 来实现特征重用(feature reuse)

       DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能

       相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。

       ResNet是每个层与前面的某层(一般是2~3层)短路连接在一起,连接方式是通过元素级相加,而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起,并作为下一层的输入。

       对于一个 L 层的网络,DenseNet共包含 L*(L+1)/2 个连接,相比ResNet,这是一种密集连接。而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。

       CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。但是一层层链接下去肯定会越来越大,为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,这是一种折中组合。

       其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低,还可以压缩模型。

       DenseBlock中的非线性组合函数 H 采用的是BN+ReLU+3x3 Conv的结构。与ResNet不同,所有DenseBlock中各个层卷积之后均输出 k 个特征图,即得到的特征图的out_channel数为 k ,或者说采用 k 个卷积核。k 在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的 k (比如12),就可以得到较佳的性能。由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是原有的结构中增加1x1 Conv,降低特征数量

       Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。

       DenseNet-C 和 DenseNet-BC

特点

经典神经网络 -- DenseNet : 设计原理与pytorch实现_第1张图片

 

代码实现

# DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection)

# DenseNet的一大特色是通过 特征在channel上的连接 来实现特征重用(feature reuse)

# DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能

# 相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,
# 具体来说就是每个层都会接受其前面所有层作为其额外的输入。

# ResNet是每个层与前面的某层(一般是2~3层)短路连接在一起,连接方式是通过元素级相加,
# 而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起,并作为下一层的输入。

# 对于一个 L 层的网络,DenseNet共包含 L*(L+1)/2 个连接,相比ResNet,这是一种密集连接。
# 而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。

# CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。
# 但是一层层链接下去肯定会越来越大,为了解决这个问题,DenseNet网络中使用DenseBlock+Transition的结构,这是一种折中组合。
# 其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。
# 而Transition模块是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低,还可以压缩模型。

# DenseBlock中的非线性组合函数 H 采用的是BN+ReLU+3x3 Conv的结构。
# 与ResNet不同,所有DenseBlock中各个层卷积之后均输出 k 个特征图,即得到的特征图的out_channel数为 k ,或者说采用 k 个卷积核。 
# k 在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的 k (比如12),就可以得到较佳的性能。
# 由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是原有的结构中增加1x1 Conv,降低特征数量

# Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1 Conv+2x2 AvgPooling。
# DenseNet-C 和 DenseNet-BC


from turtle import forward, shape
from numpy import block
import torch
import torch.nn as nn

from densenet import transition


def conv_block(in_channel, out_channel): # 一个卷积块
    layer = nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=False)
    )
    return layer


class dense_block(nn.Module):
    def __init__(self, in_channel, growth_rate, num_layers):
        super().__init__() # growth_rate => k => out_channel
        block = []
        channel = in_channel # channel => in_channel
        for i in range(num_layers):
            block.append(conv_block(channel, growth_rate))
            channel += growth_rate # 连接每层的特征
        self.net = nn.Sequential(*block) # 实现简单的顺序连接模型 
        # 必须确保前一个模块的输出大小和下一个模块的输入大小是一致的
    
    def forward(self, x):
        for layer in self.net:
            out = layer(x)
            x = torch.cat((out, x), dim=1) # contact同维度拼接特征,stack(是把list扩维连接
            # torch.cat()是为了把多个tensor进行拼接,在给定维度上对输入的张量序列seq 进行连接操作
            # inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
            # dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列
        return x


def trabsition(in_channel, out_channel):
    trans_layer = nn.Sequential(
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, out_channel, 1), # kernel_size = 1 1x1 conv
        nn.AvgPool2d(2, 2) # 2x2 pool
    )
    return trans_layer


class DenseNet121(nn.Module):
    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12, 24, 16]):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, 7, 2, 3), # padding=3 参数要熟悉
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, 2, padding=1)
        )
        self.DB1 = self._make_dense_block(64, growth_rate, num=block_layers[0])
        self.TL1 = self._make_transition_layer(256)
        self.DB2 = self._make_dense_block(128, growth_rate, num=block_layers[1])
        self.TL2 = self._make_transition_layer(512)
        self.DB3 = self._make_dense_block(256, growth_rate, num=block_layers[2])
        self.TL3 = self._make_transition_layer(1024)
        self.DB4 = self._make_dense_block(512, growth_rate, num=block_layers[3])
        self.global_avgpool = nn.Sequential( # 全局平均池化
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.classifier = nn.Linear(1024, num_classes) # fc层

    def forward(self, x):
        x = self.block1(x)
        x = self.DB1(x)
        x = self.TL1(x)
        x = self.DB2(x)
        x = self.TL2(x)
        x = self.DB3(x)
        x = self.TL3(x)
        x = self.DB4(x)
        x = self.global_avgpool(x)

    def _make_dense_block(self, channels, growth_rate, num): # num是块的个数
        block = []
        block.append(dense_block(channels, growth_rate, num))
        channels += num * growth_rate # 特征变化 # 这里记录下即可,生成时dense_block()中也做了变化
        return nn.Sequential(*block)
    
    def _make_transition_layer(self, channels):
        block = []
        block.append(transition(channels, channels//2)) # channels // 2就是为了降低复杂度 θ = 0.5
        return nn.Sequential(*block)


if __name__ == '__main__':
    net = DenseNet121(3, 10) # in_channel, num_classes
    x = torch.rand((1, 3, 224, 224))
    for name,layer in net.named_children():
        if name != 'classifier':
            x = layer(x)
            print(name, 'output shape:', x,shape)
        else:
            print(x.shape)
            x = x.view(x.shape[0], -1) # 展开tensor 分类
            print(x.shape)
            x = layer(x)
            print(name, 'output shape:', x.shape)

参考文章:

DenseNet:比ResNet更优的CNN模型 - 知乎

pytorch 实现Densenet模型 代码详解,计算过程,_视觉盛宴的博客-CSDN博客_densenet pytorch

你可能感兴趣的:(求职,CV-计算机视觉,pytorch,深度学习,算法,人工智能,计算机视觉)