SegNet 网络的 PyTorch 实现

1.文章原文地址

SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation

2.文章摘要

语义分割具有非常广泛的应用，从场景理解、推断物体之间的支撑关系到自动驾驶。早期依赖低层视觉线索的方法已迅速被流行的机器学习算法所取代。特别是，深度学习最近在手写数字识别、语音以及图像分类和目标检测上取得了巨大成功。如今，语义分割（为每个像素分配类别标签）是一个活跃的研究领域。然而，最近有一些方法直接采用为图像分类而设计的网络结构来完成语义分割任务。虽然结果令人鼓舞，但分割结果仍较为粗糙，其首要原因是最大池化和下采样降低了特征图的分辨率。我们设计 SegNet 的动机在于：分割任务需要将低分辨率的特征图映射回输入分辨率以进行像素级分类，而这一映射必须产生有助于精确定位边界的特征。

3.网络结构

[图 1：SegNet 网络结构图]

4. PyTorch 实现

  1 import torch.nn as nn
  2 import torch
  3 
  4 class conv2DBatchNormRelu(nn.Module):
  5     def __init__(self,in_channels,out_channels,kernel_size,stride,padding,
  6                  bias=True,dilation=1,is_batchnorm=True):
  7         super(conv2DBatchNormRelu,self).__init__()
  8         if is_batchnorm:
  9             self.cbr_unit=nn.Sequential(
 10                 nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,
 11                           bias=bias,dilation=dilation),
 12                 nn.BatchNorm2d(out_channels),
 13                 nn.ReLU(inplace=True),
 14             )
 15         else:
 16             self.cbr_unit=nn.Sequential(
 17                 nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
 18                           bias=bias, dilation=dilation),
 19                 nn.ReLU(inplace=True)
 20             )
 21 
 22     def forward(self,inputs):
 23         outputs=self.cbr_unit(inputs)
 24         return outputs
 25 
 26 class segnetDown2(nn.Module):
 27     def __init__(self,in_channels,out_channels):
 28         super(segnetDown2,self).__init__()
 29         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 30         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 31         self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True)
 32 
 33     def forward(self,inputs):
 34         outputs=self.conv1(inputs)
 35         outputs=self.conv2(outputs)
 36         unpooled_shape=outputs.size()
 37         outputs,indices=self.maxpool_with_argmax(outputs)
 38         return outputs,indices,unpooled_shape
 39 
 40 class segnetDown3(nn.Module):
 41     def __init__(self,in_channels,out_channels):
 42         super(segnetDown3,self).__init__()
 43         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 44         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 45         self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 46         self.maxpool_with_argmax=nn.MaxPool2d(kernel_size=2,stride=2,return_indices=True)
 47 
 48     def forward(self,inputs):
 49         outputs=self.conv1(inputs)
 50         outputs=self.conv2(outputs)
 51         outputs=self.conv3(outputs)
 52         unpooled_shape=outputs.size()
 53         outputs,indices=self.maxpool_with_argmax(outputs)
 54         return outputs,indices,unpooled_shape
 55 
 56 
 57 class segnetUp2(nn.Module):
 58     def __init__(self,in_channels,out_channels):
 59         super(segnetUp2,self).__init__()
 60         self.unpool=nn.MaxUnpool2d(2,2)
 61         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 62         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 63 
 64     def forward(self,inputs,indices,output_shape):
 65         outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
 66         outputs=self.conv1(outputs)
 67         outputs=self.conv2(outputs)
 68         return outputs
 69 
 70 class segnetUp3(nn.Module):
 71     def __init__(self,in_channels,out_channels):
 72         super(segnetUp3,self).__init__()
 73         self.unpool=nn.MaxUnpool2d(2,2)
 74         self.conv1=conv2DBatchNormRelu(in_channels,out_channels,kernel_size=3,stride=1,padding=1)
 75         self.conv2=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 76         self.conv3=conv2DBatchNormRelu(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
 77 
 78     def forward(self,inputs,indices,output_shape):
 79         outputs=self.unpool(inputs,indices=indices,output_size=output_shape)
 80         outputs=self.conv1(outputs)
 81         outputs=self.conv2(outputs)
 82         outputs=self.conv3(outputs)
 83         return outputs
 84 
 85 class segnet(nn.Module):
 86     def __init__(self,in_channels=3,num_classes=21):
 87         super(segnet,self).__init__()
 88         self.down1=segnetDown2(in_channels=in_channels,out_channels=64)
 89         self.down2=segnetDown2(64,128)
 90         self.down3=segnetDown3(128,256)
 91         self.down4=segnetDown3(256,512)
 92         self.down5=segnetDown3(512,512)
 93 
 94         self.up5=segnetUp3(512,512)
 95         self.up4=segnetUp3(512,256)
 96         self.up3=segnetUp3(256,128)
 97         self.up2=segnetUp2(128,64)
 98         self.up1=segnetUp2(64,64)
 99         self.finconv=conv2DBatchNormRelu(64,num_classes,3,1,1)
100 
101     def forward(self,inputs):
102         down1,indices_1,unpool_shape1=self.down1(inputs)
103         down2,indices_2,unpool_shape2=self.down2(down1)
104         down3,indices_3,unpool_shape3=self.down3(down2)
105         down4,indices_4,unpool_shape4=self.down4(down3)
106         down5,indices_5,unpool_shape5=self.down5(down4)
107 
108         up5=self.up5(down5,indices=indices_5,output_shape=unpool_shape5)
109         up4=self.up4(up5,indices=indices_4,output_shape=unpool_shape4)
110         up3=self.up3(up4,indices=indices_3,output_shape=unpool_shape3)
111         up2=self.up2(up3,indices=indices_2,output_shape=unpool_shape2)
112         up1=self.up1(up2,indices=indices_1,output_shape=unpool_shape1)
113         outputs=self.finconv(up1)
114 
115         return outputs
116 
if __name__ == "__main__":
    # Smoke test: push one all-ones 1x3x224x224 tensor through the default
    # model, then show the output size followed by the module tree.
    dummy_batch = torch.ones(1, 3, 224, 224)
    net = segnet()
    prediction = net(dummy_batch)
    print(prediction.size())
    print(net)

参考

https://github.com/meetshah1995/pytorch-semseg

转载于:https://www.cnblogs.com/ys99/p/10900870.html

你可能感兴趣的:(SegNet网络的Pytorch实现)