论文全称:《SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation》
论文地址:https://arxiv.org/abs/1511.00561
论文代码:
python pytorch版本https://github.com/delta-onera/segnet_pytorch
python TensorFlow版本https://github.com/tkuanlun350/Tensorflow-SegNet
C++ caffe版本https://github.com/alexgkendall/caffe-segnet
论文demo:http://mi.eng.cam.ac.uk/projects/segnet/demo.php#demo
目录
论文综述
网络结构
Max pooling 索引
代码详解
SegNet的新颖之处在于解码器对其低分辨率输入特征图进行上采样的方式。具体来说,解码器使用了在对应编码器的max-pooling步骤中所计算出的pooling索引来执行非线性上采样。这就消除了学习upsample的需要。
SegNet主要由场景理解应用程序驱动。因此,它被设计成在推理过程中在内存和计算时间方面都是有效的。与其他竞争体系结构相比,它的可训练参数数量也要少得多,并且可以使用随机梯度下降法进行端到端训练。
在解码器中重复使用max-pooling 的索引的好处有:
SegNet有一个编码器网络和相应的解码器网络,然后是最后的像素级分类层。编码器网络由13个卷积层组成,对应于vgg16网络中的前13个卷积层。丢弃完全连接的层,以便在最深的编码器输出端保留更高分辨率的特征图。这也大大减少了SegNet编码器网络中的参数数量。每个编码器层有一个对应的解码器层,因此解码器网络有13层。最后的解码器输出被馈送到一个多类软最大分类器,为每个像素独立产生类概率。
相比于其他网络:
DeconvNet的参数化要大得多,需要更多的计算资源,而且端到端训练也比较困难,这主要是由于使用了完全连接的层(尽管是以卷积的方式)。与SegNet相比,U-Net(为医学成像社区提出)不重用池索引,而是将整个feature map(以更多内存为代价)传输到相应的解码器,并将它们连接到上采样(通过反褶积)解码器feature map。U-Net中的vggnet结构没有conv5和max-pool5 。另一方面,SegNet使用来自VGG net的所有预训练卷积层权重作为预训练权重。
下面不同网络比较的结果。
假设下图中a、b、c、d对应于feature map中的值。SegNet使用Max pooling 索引向上采样(不需要学习)特征映射,并与可训练的 decoder filters 组进行卷积。FCN通过学习对输入的feature map进行解卷积,并添加相应的encoder feature map来产生decoder output,从而对FCN进行upsamples。该feature map是对应编码器中的max-pooling层(包括子采样)的输出。注意,FCN中没有可训练的 decoder filters 。这里的 decoder filters 就是指解码器中的输入输出大小不变的那部分卷积层,而不是转置卷积或者解卷积。
代码地址:https://github.com/delta-onera/segnet_pytorch/blob/master/segnet.py
代码是基于pytorch,因为基本的结构是vgg16作为编码器,与vgg16相反的结构作为解码器,所以代码可以直观理解,唯一特别注意的是如何实现max pooling 索引的问题。
这里有两个pytorch里的函数已经帮助我们实现了功能,分别是max_pool2d,max_unpool2d。
torch.nn.functional.max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False)
torch.nn.functional.max_unpool2d(input, indices, kernel_size, stride=None, padding=0, output_size=None)
参数:
- input – 输入的张量 (minibatch x in_channels x iH x iW)
- kernel_size – 池化区域的大小,可以是单个数字或者元组 (kh x kw)
- stride – 池化操作的步长,可以是单个数字或者元组 (sh x sw)。默认等于核的大小
- padding – 在输入上隐式的零填充,可以是单个数字或者一个元组 (padh x padw),默认: 0
- ceil_mode – 定义空间输出形状的操作
- count_include_pad – 除以原始非填充图像内的元素数量或kh * kw
-return_indices – 返回索引
若返回索引设置为True,max_pool2d就会返回输出和索引值,这个索引值就能被max_unpool2d所设置为indices。
就如下面那样:
x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True)
x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
最后奉上完整代码。
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
class SegNet(nn.Module):
def __init__(self,input_nbr,label_nbr):
super(SegNet, self).__init__()
batchNorm_momentum = 0.1
self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)
def forward(self, x):
# Stage 1
x11 = F.relu(self.bn11(self.conv11(x)))
x12 = F.relu(self.bn12(self.conv12(x11)))
x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True)
# Stage 2
x21 = F.relu(self.bn21(self.conv21(x1p)))
x22 = F.relu(self.bn22(self.conv22(x21)))
x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True)
# Stage 3
x31 = F.relu(self.bn31(self.conv31(x2p)))
x32 = F.relu(self.bn32(self.conv32(x31)))
x33 = F.relu(self.bn33(self.conv33(x32)))
x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True)
# Stage 4
x41 = F.relu(self.bn41(self.conv41(x3p)))
x42 = F.relu(self.bn42(self.conv42(x41)))
x43 = F.relu(self.bn43(self.conv43(x42)))
x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True)
# Stage 5
x51 = F.relu(self.bn51(self.conv51(x4p)))
x52 = F.relu(self.bn52(self.conv52(x51)))
x53 = F.relu(self.bn53(self.conv53(x52)))
x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True)
# Stage 5d
x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
x53d = F.relu(self.bn53d(self.conv53d(x5d)))
x52d = F.relu(self.bn52d(self.conv52d(x53d)))
x51d = F.relu(self.bn51d(self.conv51d(x52d)))
# Stage 4d
x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
x43d = F.relu(self.bn43d(self.conv43d(x4d)))
x42d = F.relu(self.bn42d(self.conv42d(x43d)))
x41d = F.relu(self.bn41d(self.conv41d(x42d)))
# Stage 3d
x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
x33d = F.relu(self.bn33d(self.conv33d(x3d)))
x32d = F.relu(self.bn32d(self.conv32d(x33d)))
x31d = F.relu(self.bn31d(self.conv31d(x32d)))
# Stage 2d
x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
x22d = F.relu(self.bn22d(self.conv22d(x2d)))
x21d = F.relu(self.bn21d(self.conv21d(x22d)))
# Stage 1d
x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
x12d = F.relu(self.bn12d(self.conv12d(x1d)))
x11d = self.conv11d(x12d)
return x11d
def load_from_segnet(self, model_path):
s_dict = self.state_dict()# create a copy of the state dict
th = torch.load(model_path).state_dict() # load the weigths
# for name in th:
# s_dict[corresp_name[name]] = th[name]
self.load_state_dict(th)