J3: DenseNet Algorithm in Practice and Analysis

> **This article is a learning-record blog for the [365-Day Deep Learning Training Camp](https://mp.weixin.qq.com/s/2Mc0B5c2SdivAR3WS_g1bA)**

> **Original author: [K同学啊 | tutoring and custom projects](https://mtyjkh.blog.csdn.net/)**

Tasks for this week:

1. Write TensorFlow code based on the PyTorch code (this article uses an alternative PyTorch implementation that differs from the reference code).

2. Understand the differences between DenseNet and ResNet.

3. Consider whether the improvement ideas can be transferred elsewhere.

I. Design Philosophy

DenseNet adopts a more aggressive dense connectivity mechanism: every layer receives the outputs of all preceding layers as additional input.
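Concretely (this is the standard formulation from the ResNet and DenseNet papers, and it answers task 2 above): ResNet combines features by element-wise addition, whereas DenseNet concatenates all preceding outputs along the channel dimension:

$$x_{\ell} = H_{\ell}(x_{\ell-1}) + x_{\ell-1} \quad \text{(ResNet)}$$

$$x_{\ell} = H_{\ell}([x_0, x_1, \dots, x_{\ell-1}]) \quad \text{(DenseNet)}$$

where $[\cdot]$ denotes channel-wise concatenation and $H_{\ell}$ is a composite function of BN, ReLU, and convolution.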

[Figure 1]

II. Network Structure

[Figure 2]

DenseNet consists of four main parts:

1. A 7×7 Conv: downsampling

2. DenseBlock: feature extraction

3. Transition: downsampling

4. Classification: the classification head

So the network can be built by implementing each of these modules separately and stacking them.

III. Code Implementation and Selected Explanations

The implementation proceeds in the order DenseBlock -> Transition layer -> DenseNet.

[Figure 3]

1. DenseBlock

Each layer within a DenseBlock follows the pattern BN + ReLU + Conv + BN + ReLU + Conv. The first (1×1) convolution reduces computation and adjusts the channel count, while the second (3×3) convolution extracts features and enlarges the receptive field, which is why the two kernel sizes differ. For example, with gr = 4 and bn = 4, the 1×1 conv expands to bn × gr = 16 channels and the 3×3 conv then maps back down to gr = 4 channels.

First, define the convolution building block:

import torch
import torch.nn as nn
import torch.nn.functional as F


class con(nn.Module):
    """BN + ReLU + Conv block.

    inf: input channels; gr: growth rate; bn: bottleneck multiplier
    (output channels = bn * gr); k: kernel size.
    """
    def __init__(self, inf, gr, bn, k=1):
        super(con, self).__init__()
        out = bn * gr
        self.bn = nn.BatchNorm2d(inf)
        self.act = nn.ReLU()
        # padding=(k-1)//2 keeps the spatial size unchanged, which the later
        # channel-wise concatenation requires
        self.conv = nn.Conv2d(inf, out, kernel_size=k, stride=1,
                              padding=(k - 1) // 2, bias=False)

    def forward(self, x):
        out = self.bn(x)
        out = self.act(out)
        out = self.conv(out)
        return out

Instantiate it with sample parameters and print the structure:

net = con(3, 4, 4)
print(net)

The result is as follows:

[Figure 4]

Define the DenseLayer

drop_rate: randomly zeroes activations (active only in training mode, disabled during evaluation); optional.

torch.cat([x, new], 1): concatenates the DenseLayer output with its input along the channel dimension; the 1 refers to the channel axis of a [bn, c, h, w] tensor.

gr: the growth rate, i.e. the number of feature channels each layer adds.
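As a quick illustration of how torch.cat behaves along dim 1 (the shapes below are toy values chosen only for this sketch):

a = torch.randn(2, 3, 8, 8)        # [bn, c, h, w]
b = torch.randn(2, 4, 8, 8)
c = torch.cat([a, b], 1)           # channel counts add up: 3 + 4 = 7
print(c.shape)                     # torch.Size([2, 7, 8, 8])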

class DenseLayer(nn.Module):
    def __init__(self, inf, gr, drop_rate, bn=4):
        super(DenseLayer, self).__init__()
        # 1x1 bottleneck conv: inf -> bn * gr channels
        self.conv1 = con(inf, gr, bn)
        # 3x3 conv: bn * gr -> gr channels (multiplier of 1)
        self.conv2 = con(bn * gr, gr, 1, k=3)
        self.drop_rate = drop_rate

    def forward(self, x):
        new = self.conv1(x)
        new = self.conv2(new)
        if self.drop_rate > 0:
            new = F.dropout(new, p=self.drop_rate, training=self.training)
        # concatenate the input with the newly grown features along channels
        return torch.cat([x, new], 1)

Print the structure with sample parameters:

net = DenseLayer(3, gr=4, drop_rate=1)
print(net)

The result:

[Figure 5]
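A quick shape check (with a toy input assumed here) confirms that a DenseLayer appends gr channels to its input:

layer = DenseLayer(3, gr=4, drop_rate=0)
y = layer(torch.randn(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 7, 32, 32]): 3 input channels + gr = 4 new ones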

Define the DenseBlock

Each DenseBlock stacks several DenseLayers. The input channel count of layer i is the block's input channel count plus i × growth rate; for example, with num_inf = 3 and gr = 4, the four layers of a 4-layer block receive 3, 7, 11, and 15 input channels respectively.

class DenseBlock(nn.Module):
    def __init__(self, num_layers, num_inf, drop_rate, gr, bn=4):
        super(DenseBlock, self).__init__()
        for i in range(num_layers):
            # layer i sees the block input plus the i * gr channels grown so far
            layer = DenseLayer(num_inf + i * gr, gr, drop_rate, bn=bn)
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features):
        # each DenseLayer already concatenates its input with its new features,
        # so the layers can simply be chained
        features = init_features
        for name, layer in self.named_children():
            features = layer(features)
        return features
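As a sketch of how the channels accumulate (toy input assumed), a block's output should have num_inf + num_layers × gr channels:

block = DenseBlock(num_layers=4, num_inf=3, drop_rate=0, gr=4)
out = block(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 19, 32, 32]): 3 + 4 * 4 = 19 channels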

To aid understanding, print the model structure:

n = DenseBlock(4, 3, 1, 4)
print(n)

The result (only the first two DenseLayers are shown):

[Figure 6]

2. Define the Transition

The Transition module is used for downsampling, so it is very simple; you can either reuse the con module defined earlier or write it out explicitly.

class Transition(nn.Sequential):
    def __init__(self, num_in, num_out):
        super(Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(num_in))
        self.add_module("relu", nn.ReLU())
        # 1x1 conv compresses the channels; no padding is needed
        self.add_module("conv", nn.Conv2d(num_in, num_out, kernel_size=1,
                                          stride=1, bias=False))
        # 2x2 average pooling halves the spatial resolution
        self.add_module("pool", nn.AvgPool2d(2, stride=2))
        
Net = Transition(3, 32)
print(Net)

The printed result:

[Figure 7]
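As an extra sanity check (toy shapes assumed), a Transition maps the channels to the requested count and halves the spatial resolution:

t = Transition(64, 32)
y = t(torch.randn(1, 64, 16, 16))
print(y.shape)  # torch.Size([1, 32, 8, 8]): channels compressed, H and W halved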

3. Build the DenseNet Network

The overall DenseNet structure is shown below. Note that every DenseBlock except the last is followed by a Transition, so the number of Transitions inserted is the number of DenseBlocks minus one (e.g. 4 DenseBlocks need 3 Transitions).

[Figure 8]

class DenseNet(nn.Module):
    def __init__(self, init_feature=64, block_setting=(6, 12, 24, 6),
                 drop_rate=0, gr=64, bn=4, compression_rate=0.5, num_classes=10):
        """
        init_feature: channels produced by the first conv
                      (i.e. the input channels of the first DenseBlock)
        block_setting: number of DenseLayers in each DenseBlock
                       (for reference, DenseNet-121 uses (6, 12, 24, 16))
        drop_rate: dropout rate inside the DenseLayers
        gr: growth rate, the channels each DenseLayer adds
        bn: bottleneck multiplier for the 1x1 conv (not the batch size)
        compression_rate: channel compression factor of the Transitions
        """
        super(DenseNet, self).__init__()
        # First Conv2d: padding=3 keeps the 7x7 conv centered (224 -> 112)
        self.feature = nn.Sequential(
            nn.Conv2d(3, init_feature, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(init_feature),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
            )

        # DenseBlocks, each followed by a Transition except the last
        num_feature = init_feature
        for i, num_layers in enumerate(block_setting):
            block = DenseBlock(
                num_layers=num_layers,
                num_inf=num_feature,
                drop_rate=drop_rate,
                gr=gr,
                bn=bn
                )
            self.feature.add_module('denseblock%d' % (i + 1), block)
            num_feature += num_layers * gr
            # insert a Transition and compress the channel count
            if i != len(block_setting) - 1:
                transition = Transition(num_feature, int(num_feature * compression_rate))
                self.feature.add_module('transition%d' % (i + 1), transition)
                num_feature = int(num_feature * compression_rate)

        # Final BN + ReLU
        self.final = nn.Sequential(
            nn.BatchNorm2d(num_feature),
            nn.ReLU())

        # Classification layer
        self.classifier = nn.Linear(num_feature, num_classes)

    def forward(self, x):
        features = self.final(self.feature(x))
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out

Print the network structure to inspect the details (comparing the printed output against what the reference implementation prints is enough to check your code; there is no need to verify by training):

n = DenseNet()
print(n)

The result (partial):

[Figure 9]
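Before counting parameters, a forward-pass smoke test (assuming a 224×224 RGB batch, the input size DenseNet was designed around) confirms the whole pipeline produces the expected output shape:

x = torch.randn(2, 3, 224, 224)
print(n(x).shape)  # torch.Size([2, 10]): one logit per class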

Inspect the parameter count with torchinfo:

from torchinfo import summary

model = n

summary(model)
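Note that summary(model) alone reports the module hierarchy and parameter counts; passing an input size, e.g. summary(model, input_size=(1, 3, 224, 224)), additionally reports per-layer output shapes.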

 
