DenseNet Pytorch实现

DenseNet网络实现

DenseNet和ResNet不同在于ResNet是跨层求和,而DenseNet是跨层将特征在通道维度进行拼接,下图一是ResNet,二是DenseNet。
DenseNet Pytorch实现_第1张图片
因为实在通道维度进行特征的拼接,所以底层的输出会保留进入后面的曾,这样能更好的保证梯度的传播,同时能够使用低维的特征和高维的特征进行联合训练,能够得到更好的结果。

DenseNet主要有Dense block组成,使用pytorch实现如下

def conv_block(in_channel,	out_channel):
				layer	=	nn.Sequential(
								nn.BatchNorm2d(in_channel),
								nn.ReLU(True),
								nn.Conv2d(in_channel,	out_channel,	3,	padding=1,	bias=False)
				)
				return layer


class dense_block(nn.Module):
    def __init__(self, in_channel, growth_rate, num_layers):
        super(dense_block, self).__init__()
        block = []
        channel = in_channel
        for i in range(num_layers):
            block.append(conv_block(channel, growth_rate))
            channel += growth_rate

        self.net = nn.Sequential(*block)

    def forward(self, x):
        for layer in self.net:
            out = layer(x)
            x = torch.cat((out, x), dim=1)
        return x

定义DenseNet:

class densenet(nn.Module):
    def __init__(self, in_channel, num_classes, growth_rate=32, block_layers=[6, 12,
                                                                              24, 16]):
        super(densenet, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, 7, 2, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, 2, padding=1)
        )

        channels = 64
        block = []
        for i, layers in enumerate(block_layers):
            block.append(dense_block(channels, growth_rate, layers))
            channels += layers * growth_rate
            if i != len(block_layers) - 1:
                block.append(transition(channels, channels // 2))  # ᭗ᬦ	transition	੶


    channels = channels // 2
    self.block2 = nn.Sequential(*block)
    self.block2.add_module('bn', nn.BatchNorm2d(channels))
    self.block2.add_module('relu', nn.ReLU(True))
    self.block2.add_module('avg_pool', nn.AvgPool2d(3))
    self.classifier = nn.Linear(channels, num_classes)
    
def forward(self, x):
    x = self.block1(x)
    x = self.block2(x)

    x = x.view(x.shape[0], -1)
    x = self.classifier(x)
    return x

你可能感兴趣的:(卷积神经网络专题)