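>- **This post is a study-log entry from the [365天深度学习训练营 (365-Day Deep Learning Training Camp)](https://mp.weixin.qq.com/s/2Mc0B5c2SdivAR3WS_g1bA)**
>- **Original author: [k同学啊 | tutoring and custom projects available](https://mtyjkh.blog.csdn.net/)**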
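This week's tasks:
1. Write TensorFlow code based on the PyTorch code (this post uses a different PyTorch implementation from the original source code)
2. Understand the differences between DenseNet and ResNetV2
3. Consider whether the improvement ideas can be transferred to other networks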
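DenseNet introduces a more aggressive dense connectivity mechanism: every layer takes the feature maps of all preceding layers as additional input.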
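To make the contrast with ResNet concrete (task 2), here is a minimal sketch with arbitrary toy shapes chosen for illustration. A ResNet-style shortcut adds a layer's output to its input, so the channel count stays fixed. A DenseNet-style connection concatenates along the channel axis, so the channel count grows with every layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 8, 32, 32)  # [N, C, H, W], toy sizes

# ResNet-style shortcut: element-wise sum, output keeps 8 channels
f = nn.Conv2d(8, 8, kernel_size=3, padding=1)
res_out = x + f(x)

# DenseNet-style connection: channel concatenation, 8 + 4 = 12 channels
g = nn.Conv2d(8, 4, kernel_size=3, padding=1)
dense_out = torch.cat([x, g(x)], dim=1)

print(res_out.shape)    # torch.Size([1, 8, 32, 32])
print(dense_out.shape)  # torch.Size([1, 12, 32, 32])
```

The first 8 channels of `dense_out` are exactly `x`. Earlier features are passed on unchanged instead of being mixed into a sum.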
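DenseNet consists of four main parts:
1. 7×7 Conv (followed by 3×3 max pooling): downsampling
2. DenseBlock: feature extraction
3. Transition: downsampling
4. Classification: classification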
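The downsampling happens in the stem and in the Transitions only. The sketch below traces the spatial size through the network. It assumes a 224×224 input and padding=3 on the 7×7 conv, as in torchvision's DenseNet. DenseBlocks keep H and W unchanged because their 3×3 convs use padding 1.

```python
def conv_out(n, k, s, p):
    """Output size of a conv/pool layer along one spatial dimension."""
    return (n + 2 * p - k) // s + 1

size = 224                        # assumed input resolution
size = conv_out(size, 7, 2, 3)    # 7x7 conv, stride 2, padding 3 -> 112
size = conv_out(size, 3, 2, 1)    # 3x3 max pool, stride 2, padding 1 -> 56
trace = [size]
for _ in range(3):                # 3 Transitions (4 DenseBlocks - 1)
    size = conv_out(size, 2, 2, 0)  # 2x2 avg pool, stride 2 halves H and W
    trace.append(size)
print(trace)                      # [56, 28, 14, 7]
```

The final 7×7 map is what a 7×7 average pool before the classifier expects. With padding=1 on the 7×7 conv, `conv_out(224, 7, 2, 1)` gives 110 instead of 112, so the last map no longer comes out as 7×7.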
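The network can therefore be built by implementing each module separately and stacking them, in the order DenseBlock (built from DenseLayers) -> Transition layer -> DenseNet.
Each DenseLayer inside a DenseBlock follows BN+ReLU+Conv(1×1)+BN+ReLU+Conv(3×3). The first (1×1) convolution is a bottleneck: it changes the channel count to reduce computation. The second (3×3) convolution extracts features and enlarges the receptive field, which is why the two convolutions use different kernel sizes.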
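The bottleneck's saving can be checked by counting conv weights. This sketch assumes growth rate 32 and bottleneck multiplier 4 (the DenseNet-121 values), uses bias=False throughout, and ignores BN parameters. It compares a single 3×3 conv against the 1×1 + 3×3 pair for several input widths.

```python
def direct_params(c_in, gr):
    """Weights of a single 3x3 conv c_in -> gr (bias=False)."""
    return c_in * gr * 3 * 3

def bottleneck_params(c_in, gr, bn=4):
    """Weights of a 1x1 conv c_in -> bn*gr followed by a 3x3 conv bn*gr -> gr (bias=False)."""
    return c_in * (bn * gr) + (bn * gr) * gr * 3 * 3

for c_in in (64, 256, 1024):
    print(c_in, direct_params(c_in, 32), bottleneck_params(c_in, 32))
```

The bottleneck only pays off once the input is wide. For 64 input channels it needs more weights than a direct 3×3 conv. For 1024 channels it needs 167,936 weights versus 294,912. Inputs become wide deep inside a DenseBlock because features keep accumulating.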
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


class con(nn.Module):
    """BN -> ReLU -> Conv unit; the conv outputs bn * gr channels."""
    def __init__(self, inf, gr, bn, k=1):
        super(con, self).__init__()
        out = bn * gr
        self.bn = nn.BatchNorm2d(inf)
        self.act = nn.ReLU()
        # padding = k // 2 keeps H and W unchanged, which torch.cat in DenseLayer needs
        self.conv = nn.Conv2d(inf, out, kernel_size=k, stride=1, padding=k // 2, bias=False)

    def forward(self, x):
        out = self.bn(x)
        out = self.act(out)
        out = self.conv(out)
        return out
```
Print the structure with some sample arguments to inspect it:
```python
net = con(3, 4, 4)
print(net)
```
The output is as follows:
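- drop_rate: dropout probability for randomly dropping activations. Dropout is applied only in training mode and disabled during evaluation. This parameter is optional.
- torch.cat((x, new), 1): concatenates the DenseLayer's new features with its input along the channel dimension. The 1 refers to the channel axis of an [N, C, H, W] tensor.
- gr: growth rate, i.e. the number of new feature maps each DenseLayer adds.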
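Both points above can be checked on toy tensors of arbitrary size: concatenation along dim 1, and dropout being active only during training. The sketch also shows why the 3×3 conv in a DenseLayer needs padding=1. `torch.cat` only accepts tensors whose non-concatenated dimensions match.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 16, 32, 32)  # [N, C, H, W]

# 3x3 conv with padding=1 keeps H and W, so concatenation works
conv_same = nn.Conv2d(16, 4, kernel_size=3, padding=1, bias=False)
new = conv_same(x)
out = torch.cat((x, new), 1)    # dim 1 = channels: 16 + 4 = 20
print(out.shape)                # torch.Size([2, 20, 32, 32])

# Without padding, H and W shrink to 30, so concatenating along dim 1 fails
conv_valid = nn.Conv2d(16, 4, kernel_size=3, padding=0, bias=False)
try:
    torch.cat((x, conv_valid(x)), 1)
    cat_failed = False
except RuntimeError:
    cat_failed = True
print("cat without padding failed:", cat_failed)

# F.dropout leaves the tensor unchanged when training=False (evaluation mode)
same = torch.equal(F.dropout(new, p=0.5, training=False), new)
print("dropout is identity in eval:", same)
```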
```python
class DenseLayer(nn.Module):
    def __init__(self, inf, gr, drop_rate, bn=4):
        super(DenseLayer, self).__init__()
        # 1x1 bottleneck: inf -> bn * gr channels
        self.conv1 = con(inf, gr, bn)
        # 3x3 conv: bn * gr -> gr channels (bn=1, so out = gr)
        self.conv2 = con(bn * gr, gr, 1, k=3)
        self.drop_rate = drop_rate

    def forward(self, x):
        new = self.conv1(x)
        new = self.conv2(new)
        if self.drop_rate > 0:
            new = F.dropout(new, p=self.drop_rate, training=self.training)
        # concatenate the input and the new features along the channel axis
        return torch.cat((x, new), 1)
```
Print the structure with some sample arguments to inspect it:
```python
net = DenseLayer(3, gr=4, drop_rate=1)
print(net)
```
The output is as follows:
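Each DenseBlock stacks several DenseLayers. Every layer appends gr channels to its input, so the i-th DenseLayer (counting from 0) receives the block's input channels + i × gr channels.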
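The channel rule above is plain arithmetic. This sketch assumes the first block of DenseNet-121 for concreteness: 64 input channels, 6 layers, growth rate 32.

```python
def dense_block_channels(num_inf, num_layers, gr):
    """Input channels of each DenseLayer and the block's output channels."""
    inputs = [num_inf + i * gr for i in range(num_layers)]
    output = num_inf + num_layers * gr
    return inputs, output

inputs, output = dense_block_channels(num_inf=64, num_layers=6, gr=32)
print(inputs)  # [64, 96, 128, 160, 192, 224]
print(output)  # 256
```

These are the `inf` values each DenseLayer must be constructed with. A mismatch shows up as a BatchNorm channel error on the first forward pass.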
```python
class DenseBlock(nn.Module):
    def __init__(self, num_layers, num_inf, drop_rate, gr, bn=4):
        super(DenseBlock, self).__init__()
        for i in range(num_layers):
            # layer i sees the block input plus the i * gr channels added so far
            layer = DenseLayer(num_inf + i * gr, gr, drop_rate, bn=bn)
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, x):
        # each DenseLayer already returns torch.cat((input, new), 1),
        # so chaining the layers yields num_inf + num_layers * gr channels
        for layer in self.children():
            x = layer(x)
        return x
```
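To make the structure easier to understand, print the model:
```python
n = DenseBlock(4, 3, 1, 4)
print(n)
```
The output is as follows (only the first two DenseLayers are shown):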
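The Transition module performs downsampling, so it is very simple. It applies BN + ReLU, then a 1×1 convolution that compresses the channels, then 2×2 average pooling that halves the height and width. It can reuse the con module defined earlier or be written out from scratch, as done here.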
```python
class Transition(nn.Sequential):
    def __init__(self, num_in, num_out):
        super(Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(num_in))
        self.add_module("relu", nn.ReLU())
        # a 1x1 conv needs no padding; padding=1 would grow H and W by 2
        self.add_module("conv", nn.Conv2d(num_in, num_out, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(2, stride=2))


Net = Transition(3, 32)
print(Net)
```
The output is as follows:
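Transition can also be built by reusing the con module, as mentioned above. Here is a sketch of that alternative. `transition_from_con` is a hypothetical helper name. `con` is restated so the block runs on its own, with the padding = k // 2 assumption. The 448 -> 224 sizes correspond to the first Transition under this post's defaults: 64 + 6 × 64 channels, compression 0.5.

```python
import torch
import torch.nn as nn


class con(nn.Module):
    """BN -> ReLU -> Conv unit producing bn * gr channels (padding keeps H, W)."""
    def __init__(self, inf, gr, bn, k=1):
        super().__init__()
        self.bn = nn.BatchNorm2d(inf)
        self.act = nn.ReLU()
        self.conv = nn.Conv2d(inf, bn * gr, kernel_size=k, stride=1, padding=k // 2, bias=False)

    def forward(self, x):
        return self.conv(self.act(self.bn(x)))


def transition_from_con(num_in, num_out):
    # Hypothetical helper: con with gr=num_out, bn=1 is a BN-ReLU-1x1 conv num_in -> num_out
    return nn.Sequential(con(num_in, num_out, 1), nn.AvgPool2d(2, stride=2))


t = transition_from_con(448, 224)
y = t(torch.randn(2, 448, 56, 56))
print(y.shape)  # torch.Size([2, 224, 28, 28])
```

The explicit `nn.Sequential` version shows the layer names in the printout. The con-based version avoids repeating the BN-ReLU-Conv pattern.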
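The overall DenseNet is assembled as follows. Every DenseBlock except the last must be followed by a Transition, so the number of Transitions is the number of DenseBlocks minus 1.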
```python
class DenseNet(nn.Module):
    def __init__(self, init_feature=64, block_setting=(6, 12, 24, 6),
                 drop_rate=0, gr=64, bn=4, compression_rate=0.5, num_classes=10):
        """
        init_feature: output channels of the first conv (input channels of the first DenseBlock)
        block_setting: number of DenseLayers in each DenseBlock (its length is the number of DenseBlocks)
        drop_rate: dropout probability inside each DenseLayer
        gr: growth rate, the channels added by each DenseLayer
        bn: bottleneck multiplier; the 1x1 conv outputs bn * gr channels
        compression_rate: channel compression factor applied by each Transition
        num_classes: number of output classes
        """
        super(DenseNet, self).__init__()
        # First Conv2d: 7x7 stride-2 conv + 3x3 stride-2 max pool (224 -> 56)
        self.feature = nn.Sequential(
            nn.Conv2d(3, init_feature, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(init_feature),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        # DenseBlock
        num_feature = init_feature
        for i, num_layers in enumerate(block_setting):
            block = DenseBlock(
                num_layers=num_layers,
                num_inf=num_feature,
                drop_rate=drop_rate,
                gr=gr,
                bn=bn
            )
            self.feature.add_module('denseblock%d' % (i + 1), block)
            num_feature += num_layers * gr
            # insert a Transition after every block except the last
            if i != len(block_setting) - 1:
                transition = Transition(num_feature, int(num_feature * compression_rate))
                self.feature.add_module('transition%d' % (i + 1), transition)
                num_feature = int(num_feature * compression_rate)
        # Final BN+ReLU
        self.final = nn.Sequential(
            nn.BatchNorm2d(num_feature),
            nn.ReLU())
        # Classification layer
        self.classifier = nn.Linear(num_feature, num_classes)

    def forward(self, x):
        features = self.final(self.feature(x))
        # the 7x7 average pool assumes a 224x224 input (7x7 final feature map)
        out = F.avg_pool2d(features, 7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out
```
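The classifier's input size comes from the channel bookkeeping in `__init__`. This sketch replays that bookkeeping in plain Python. It assumes each Transition's compressed output is what feeds the next DenseBlock. It is checked against the DenseNet-121 configuration (growth rate 32, blocks 6/12/24/16), whose classifier takes 1024 features.

```python
def trace_channels(init_feature, block_setting, gr, compression_rate):
    """Channels after each DenseBlock and Transition; returns the classifier input size."""
    num_feature = init_feature
    trace = []
    for i, num_layers in enumerate(block_setting):
        num_feature += num_layers * gr                # a DenseBlock adds num_layers * gr
        trace.append(("denseblock%d" % (i + 1), num_feature))
        if i != len(block_setting) - 1:               # no Transition after the last block
            num_feature = int(num_feature * compression_rate)
            trace.append(("transition%d" % (i + 1), num_feature))
    return trace, num_feature

# Defaults of the DenseNet class above
trace, final = trace_channels(64, (6, 12, 24, 6), gr=64, compression_rate=0.5)
print(trace)
print(final)      # 1400

# DenseNet-121 configuration
_, final_121 = trace_channels(64, (6, 12, 24, 16), gr=32, compression_rate=0.5)
print(final_121)  # 1024
```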
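Print the network structure to see the details. Comparing this printout with the printout of the original source code is a quick way to check your own code without training. Printing only shows the modules registered in `__init__`, though; it does not execute `forward()`.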
```python
n = DenseNet()
print(n)
```
The output is as follows (partial):
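Use torchinfo to view the parameter count:
```python
from torchinfo import summary

model = n
summary(model)  # passing input_size=(1, 3, 224, 224) also lists per-layer output shapes
```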
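The totals torchinfo reports can be cross-checked by summing `p.numel()` over the model's parameters. The sketch below does this on a small toy model, a hypothetical stand-in chosen so the count is easy to verify by hand. BatchNorm's running mean and variance are buffers, not parameters, so they are not counted.

```python
import torch.nn as nn


def count_params(model):
    """Total and trainable parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


# Toy stem + classifier with hand-checkable sizes
toy = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, bias=False),  # 3 * 64 * 7 * 7 = 9408
    nn.BatchNorm2d(64),                            # weight + bias = 128
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),                             # 64 * 10 + 10 = 650
)
print(count_params(toy))  # (10186, 10186)
```

Calling `count_params(n)` on the full DenseNet gives a number to compare against torchinfo's "Total params" line.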