编码器:编码器部分主要由普通卷积层和下采样层将特征图尺寸缩小,使其成为更低微的表征。目的是提取更多低级特征和低级特征,从而利用提取到的空间信息和全局信息精确分割。
解码器:主要是普通卷积、上采样层和融合层组成。利用上采样操作逐步恢复空间维度,融合编码过程中提取到的特征,在尽可能减少信息损失的前提下完成同尺寸的输出。
当一个复杂的前馈神经网络被训练在小的数据集时,容易造成过拟合,为了防止过拟合,可以通过阻止特征检测器的共同作用来提高网络性能。Dropout可以作为训练深度神经网络的一种技巧供选择。在每个批次训练中,可以忽略一半的特征检测器,可以明显的减少过拟合现象,这种方式可以减少特征检测器间的互相作用。
编码器中的每一个最大池化层的索引都存储了起来,用于之后在解码器中使用那些存储的索引来对应特征图进行去池化操作,这有助于保持高频信息的完整性,当但对地分辨图进行反池化时,他也会忽略临近信息。
近期的许多语义分割研究采用DNN,但是结果比较粗糙,主要原因是max-pooling和sub-sampling降低了特征图的分辨率,道路场景理解需要算法具有appearance外形、shape形状和理解空间关系(上下文)的能力。由于是道路场景,因此需要网络能够产生光滑的分割,网络也必须有能力勾画出小尺寸的物体,因此在提取图片特征过程中保留边界信息很重要。重用max-pooling indices的优点,提高边界够画,减少了进行端到端训练的参数,这种上采样形式可以被集成到任何encode-decode架构的网络中。
导入需要用到的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
这里是用VGG16训练的,从models里面调用
vgg16_pretrained = models.vgg16(pretrained=False)
因为整个网络的解码部分是通过对VGG中对应编码部分池化做的反池化,所以需要将编码部分池化的索引拿到
pool_list = [4,9,16,23,30]
for index in pool_list:
vgg16_pretrained.features[index].return_indices = True
在解码部分,里面相同的模块我们可以用函数来包起来
def decode(input_channel,out_channel,num=3):
if num == 3:
decode_boby = nn.Sequential(
nn.Conv2d(input_channel,input_channel,3,padding=1),
nn.Conv2d(input_channel,input_channel,3,padding=1),
nn.Conv2d(input_channel,out_channel,3,padding=1)
)
if num == 2:
decode_boby = nn.Sequential(
nn.Conv2d(input_channel,input_channel,3,padding=1),
nn.Conv2d(input_channel,out_channel,3,padding=1)
)
return decode_boby
网络部分 没有加softmax这部分,读者可以自行加上
self.encode1 = vgg16_pretrained.features[:4]
self.pool1 = vgg16_pretrained.features[4]
self.encode2 = vgg16_pretrained.features[5:9]
self.pool2 = vgg16_pretrained.features[9]
self.encode3 = vgg16_pretrained.features[10:16]
self.pool3 = vgg16_pretrained.features[16]
self.encode4 = vgg16_pretrained.features[17:23]
self.pool4 = vgg16_pretrained.features[23]
self.encode5 = vgg16_pretrained.features[24:30]
self.pool5 = vgg16_pretrained.features[30]
self.decode5 = decode(512,512)
self.uppool5 = nn.MaxUnpool2d(2,2)
self.decode4 = decode(512,256)
self.uppool4 = nn.MaxUnpool2d(2,2)
self.decode3 = decode(256,128)
self.uppool3 = nn.MaxUnpool2d(2,2)
self.decode2 = decode(128,64,2)
self.uppool2 = nn.MaxUnpool2d(2,2)
self.decode1 = decode(64,12,2)
self.uppool1 = nn.MaxUnpool2d(2,2)
正向传播部分
encode1 = self.encode1(x)
output_size1 = encode1.size()
pool1,indices1 = self.pool1(encode1)
encode2 = self.encode2(pool1)
output_size2 = encode2.size()
pool2,indices2 = self.pool2(encode2)
encode3 = self.encode3(pool2)
output_size3 = encode3.size()
pool3,indices3 = self.pool3(encode3)
encode4 = self.encode4(pool3)
output_size4 = encode4.size()
pool4,indices4 = self.pool4(encode4)
encode5 = self.encode5(pool4)
output_size5 = encode5.size()
pool5,indices5 = self.pool5(encode5)
uppool5 = self.uppool5(pool5,indices5,output_size5)
decode5 = self.decode5(uppool5)
print(decode5.size())
uppool4 = self.uppool4(decode5,indices4,output_size4)
decode4 = self.decode4(uppool4)
print(decode4.size())
uppool3 = self.uppool3(decode4,indices3,output_size3)
decode3 = self.decode3(uppool3)
print("3:",decode3.size())
uppool2 = self.uppool2(decode3,indices2,output_size2)
print(uppool2.size())
decode2 = self.decode2(uppool2)
uppool1 = self.uppool1(decode2,indices1,output_size1)
decode1 = self.decode1(uppool1)
代码测试
img = torch.rand((1,3,480,480))
img.size()
SegNet = VGG16_SegNet()
SegNet(img)
结果:
input_img torch.Size([1, 3, 480, 480])
torch.Size([1, 512, 30, 30])
torch.Size([1, 256, 60, 60])
3: torch.Size([1, 128, 120, 120])
torch.Size([1, 128, 240, 240])
torch.Size([1, 12, 480, 480])