Convolutional networks for classification typically output a single class label, but many vision tasks require output that carries localization information, for example assigning a class label to every pixel of the image. This calls for a fully convolutional network (FCN). Compared with the plain FCN, UNet does not simply feed the final convolutional feature map through transposed convolutions; it also merges in the intermediate feature maps from the earlier convolution layers, so the network exploits both the localization information of the contracting path and the semantic information of the deepest layer.
UNet consists of two stages: a downsample (contracting) stage and an upsample (expanding) stage.
In the architecture figure, each blue box is a feature map; the number on top gives its channel count, and the number at the lower-left gives its spatial size. The input here is 572x572 with 1 channel. Each white box is the center-cropped output of the corresponding convolution layer; it is combined with the adjacent blue box and used as the input to the following layers.
Because the convolutions use no padding, the feature maps shrink at every layer. During upsampling, the output of each downsample stage therefore has to be center-cropped before it can be concatenated (along the channel dimension) with the transposed-convolution result. For example, stage_4 is 64x64 while the corresponding transposed-convolution output is 56x56, so 4 pixels must be cropped from each side before concatenation.
import torch

def concat(tensor1, tensor2):
    # Concatenate two (N, C, H, W) tensors along the channel axis,
    # center-cropping the spatially larger one so the sizes match.
    # Assumes roughly square feature maps, as in this network.
    if tensor1.size(3) < tensor2.size(3):
        tensor1, tensor2 = tensor2, tensor1
    crop_val = (tensor1.size(3) - tensor2.size(3)) // 2
    tensor1 = tensor1[:, :, crop_val:crop_val + tensor2.size(2),
                      crop_val:crop_val + tensor2.size(3)]
    return torch.cat((tensor1, tensor2), dim=1)
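A quick sanity check of the crop-and-concat behavior (a minimal sketch; the tensor sizes correspond to the fourth stage of the network below, and the concat function above is assumed to be in scope):

stage = torch.randn(1, 512, 64, 64)   # encoder feature map
up = torch.randn(1, 512, 56, 56)      # transposed-convolution output
out = concat(stage, up)
print(out.shape)                      # torch.Size([1, 1024, 56, 56])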
Complete code (note that, for brevity, this implementation omits the ReLU activations that follow each convolution in the original UNet paper):
import torch
import torch.nn as nn
from torchsummary import summary
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution without padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution without padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride)
def up_conv2x2(in_planes, out_planes):
    return nn.ConvTranspose2d(in_planes, out_planes, kernel_size=2, stride=2)

def max_pool2x2():
    return nn.MaxPool2d(kernel_size=2, stride=2)
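# Shape arithmetic for the building blocks above (a reference note;
# assumes even spatial sizes, as with the 572x572 input used below):
#   conv3x3:     H -> H - 2   (3x3 kernel, no padding)
#   max_pool2x2: H -> H // 2
#   up_conv2x2:  H -> 2 * H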
class UNet(nn.Module):
    def __init__(self, class_num=2):
        super(UNet, self).__init__()
        # downsample stage
        self.conv_1 = nn.Sequential(conv3x3(1, 64), conv3x3(64, 64))
        self.conv_2 = nn.Sequential(conv3x3(64, 128), conv3x3(128, 128))
        self.conv_3 = nn.Sequential(conv3x3(128, 256), conv3x3(256, 256))
        self.conv_4 = nn.Sequential(conv3x3(256, 512), conv3x3(512, 512))
        self.conv_5 = nn.Sequential(conv3x3(512, 1024), conv3x3(1024, 1024))
        self.maxpool = max_pool2x2()
        # upsample stage
        # up_conv_4 corresponds to conv_4
        self.up_conv_4 = up_conv2x2(1024, 512)
        # conv the cat(stage_4, up_conv_4) from 1024 to 512 channels
        self.conv_6 = nn.Sequential(conv3x3(1024, 512), conv3x3(512, 512))
        # up_conv_3 corresponds to conv_3
        self.up_conv_3 = up_conv2x2(512, 256)
        # conv the cat(stage_3, up_conv_3) from 512 to 256 channels
        self.conv_7 = nn.Sequential(conv3x3(512, 256), conv3x3(256, 256))
        # up_conv_2 corresponds to conv_2
        self.up_conv_2 = up_conv2x2(256, 128)
        # conv the cat(stage_2, up_conv_2) from 256 to 128 channels
        self.conv_8 = nn.Sequential(conv3x3(256, 128), conv3x3(128, 128))
        # up_conv_1 corresponds to conv_1
        self.up_conv_1 = up_conv2x2(128, 64)
        # conv the cat(stage_1, up_conv_1) from 128 to 64 channels
        self.conv_9 = nn.Sequential(conv3x3(128, 64), conv3x3(64, 64))
        # output: 1x1 conv mapping 64 channels to class_num classes
        self.result = conv1x1(64, class_num)
    def _concat(self, tensor1, tensor2):
        # Concatenate two (N, C, H, W) tensors along the channel axis,
        # center-cropping the spatially larger one so the sizes match.
        if tensor1.size(3) < tensor2.size(3):
            tensor1, tensor2 = tensor2, tensor1
        crop_val = (tensor1.size(3) - tensor2.size(3)) // 2
        tensor1 = tensor1[:, :, crop_val:crop_val + tensor2.size(2),
                          crop_val:crop_val + tensor2.size(3)]
        return torch.cat((tensor1, tensor2), dim=1)
    def forward(self, x):
        # four encoder stages
        stage_1 = self.conv_1(x)
        stage_2 = self.conv_2(self.maxpool(stage_1))
        stage_3 = self.conv_3(self.maxpool(stage_2))
        stage_4 = self.conv_4(self.maxpool(stage_3))
        # bottleneck, then up-convolve and concat with stage_4
        up_in_4 = self.conv_5(self.maxpool(stage_4))
        up_stage_4 = self.up_conv_4(up_in_4)
        up_stage_4 = self._concat(stage_4, up_stage_4)
        # up-convolve and concat with stage_3
        up_in_3 = self.conv_6(up_stage_4)
        up_stage_3 = self.up_conv_3(up_in_3)
        up_stage_3 = self._concat(stage_3, up_stage_3)
        # up-convolve and concat with stage_2
        up_in_2 = self.conv_7(up_stage_3)
        up_stage_2 = self.up_conv_2(up_in_2)
        up_stage_2 = self._concat(stage_2, up_stage_2)
        # up-convolve and concat with stage_1
        up_in_1 = self.conv_8(up_stage_2)
        up_stage_1 = self.up_conv_1(up_in_1)
        up_stage_1 = self._concat(stage_1, up_stage_1)
        # final double conv back to 64 channels
        out = self.conv_9(up_stage_1)
        # 1x1 conv to class_num output channels
        out = self.result(out)
        return out
if __name__ == '__main__':
    ut = UNet(2)
    summary(ut, (1, 572, 572))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 570, 570] 640
Conv2d-2 [-1, 64, 568, 568] 36,928
MaxPool2d-3 [-1, 64, 284, 284] 0
Conv2d-4 [-1, 128, 282, 282] 73,856
Conv2d-5 [-1, 128, 280, 280] 147,584
MaxPool2d-6 [-1, 128, 140, 140] 0
Conv2d-7 [-1, 256, 138, 138] 295,168
Conv2d-8 [-1, 256, 136, 136] 590,080
MaxPool2d-9 [-1, 256, 68, 68] 0
Conv2d-10 [-1, 512, 66, 66] 1,180,160
Conv2d-11 [-1, 512, 64, 64] 2,359,808
MaxPool2d-12 [-1, 512, 32, 32] 0
Conv2d-13 [-1, 1024, 30, 30] 4,719,616
Conv2d-14 [-1, 1024, 28, 28] 9,438,208
ConvTranspose2d-15 [-1, 512, 56, 56] 2,097,664
Conv2d-16 [-1, 512, 54, 54] 4,719,104
Conv2d-17 [-1, 512, 52, 52] 2,359,808
ConvTranspose2d-18 [-1, 256, 104, 104] 524,544
Conv2d-19 [-1, 256, 102, 102] 1,179,904
Conv2d-20 [-1, 256, 100, 100] 590,080
ConvTranspose2d-21 [-1, 128, 200, 200] 131,200
Conv2d-22 [-1, 128, 198, 198] 295,040
Conv2d-23 [-1, 128, 196, 196] 147,584
ConvTranspose2d-24 [-1, 64, 392, 392] 32,832
Conv2d-25 [-1, 64, 390, 390] 73,792
Conv2d-26 [-1, 64, 388, 388] 36,928
Conv2d-27 [-1, 2, 388, 388] 130
================================================================
Total params: 31,030,658
Trainable params: 31,030,658
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.25
Forward/backward pass size (MB): 1096.59
Params size (MB): 118.37
Estimated Total Size (MB): 1216.21
----------------------------------------------------------------
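As a quick end-to-end check (a minimal sketch using the UNet class above), a single-channel 572x572 input yields a 2-channel 388x388 output, matching the last Conv2d row of the summary:

net = UNet(2)
x = torch.randn(1, 1, 572, 572)
y = net(x)
print(y.shape)  # torch.Size([1, 2, 388, 388])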