卷积神经网络多数被用在图片分类任务上,其输入是一张图片,输出是图片的类别(一般以向量形式呈现)。但是U-NET所完成的任务是像素级的分类,将每一个输入的像素归为某一类,不同的类以不同的颜色呈现,这样就可以做到输入一张图片,输出结果是图片上每一个像素的类别。
我们可以看到,U-NET将图片中每一个不同的部分且分开来,并可以用不同的颜色表示出来;同时U-NET对于训练图片数量要求不高,因为这些特性,U-NET长用于医学图像分割领域,并且基于U-NET发展出了很多其他的网络结构。
U-NET网络完全由卷积、上采样和下采样实现,没有MLP层。
U-NET之所以叫这个名字,是因为其网络结构是一个“U”型,故取名为U-NET。我们首先通过图片看一下U-NET的结构。
下面我们解释一下网络中不同位置的数字、不同颜色的箭头分别代表什么意思。网络结构中每一个矩形的大小、数字的位置等都是非常考究的,不同位置的数字有不同的意思,相同维度的矩阵一定要用相同大小的矩形表示。
首先我们输入维度为 [ 572 , 572 ] [572,572] [572,572]的单通道图像(可以看到,蓝色矩形表示代表的是图像或者网络从图像中抽取的特征,写在左侧的数字代表的是此时图像的维度,写在上方的数字代表图像的通道数,这个图相当于从侧面观察U-NET的运算过程),之后经过卷积,将通道数变为64,同时由于卷积的影响,图片的维度变为了 [ 570 , 570 ] [570,570] [570,570],其实这个时候将其成为图片已经不如将其称之为图片特征合适了。经过两次卷积之后,U-NET网络进行最大池化,将图片的两个维度都缩小了一半,浓缩了图像的特征,并继续卷积,提取特征,这就是网络左侧的计算过程。
左侧在论文中叫做Contracting Path,可见其目的就是提取图片特征。我们将其概括为以下步骤:
网络结构的右侧与左侧结构类似,右侧名为Expansive Path,目的是将左侧提取出的特征还原成原始图片大小并做到像素级别分类。
右侧的步骤概括如下:
对于拼接时尺寸稍大需要裁剪额过程,通过微调网络结构可以省略裁剪的步骤。比如在左侧进行卷积时,设置padding=1,可以使得卷积前后图像大小一样,只通过下采样改变图像的维度。
这里的拼接操作很重要,将上采样预测的特征和下采样获取的特征相拼接再进行预测,比起省略这一步骤,可以使网络有更好的分割效果。
论文中提到U-NET网络的loss计算使用交叉熵与softmax函数**(pytorch提供的crossEntropy里面自带了softmax函数,不需要额外添加)**,并且使用了高达0.99的momentum,使得之前梯度下降的方向也会参与到这次梯度下降方向的抉择中。
论文中对于loss的计算还进行了创新:为了凸显某些像素点更加重要,我们在公式中引入了 ω ( X ) ω(X) ω(X) ,并且对每一张标注图像预计算了一个权重图,来补偿训练集中每类像素的不同频率,使网络更注重学习相互接触的细胞之间的小的分割边界。
这使得论文中计算loss的交叉熵公式如下:
上文所说的权重图中的权重计算公式如下:
但是论文没有对应的原始代码
在Community Code中的实现,对于loss的处理或直接调用交叉熵或者nn.BCELoss()
完成。
由于U-NET用于医学图像处理,医学图像往往数量较少,同时在处理电镜、显微镜等图像时,需要网络具有平移、旋转不变性的特征,才能有更好的分割效果。所以论文中对训练集数据进行了数据增强,通过旋转、平移等手段增大了数据量,也增加了模型见识到的数据的类型。数据增强在模型训练中有比较重要的作用。
下采样部分复现论文中的结构,但是增加啦padding,使得卷积后特征的维度不变。
class double_conv2d_bn(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size=3,strides=1,padding=1):
'''
padding=1目的是为了卷积后图片的形状不变
pool层将图像的维度减小一半 [length,width]==>[length/2,width/2]
'''
super(double_conv2d_bn,self).__init__()
self.conv1 = nn.Conv2d(in_channels,out_channels,
kernel_size=kernel_size,
stride = strides,padding=padding,bias=True)
self.conv2 = nn.Conv2d(out_channels,out_channels,
kernel_size = kernel_size,
stride = strides,padding=padding,bias=True)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self,x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
return out
反卷积使用了pytorch提供的ConvTranspose2d()
函数,当然也可以由别的实现方法。
class deconv2d_bn(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size=2,strides=2):
super(deconv2d_bn,self).__init__()
self.conv1 = nn.ConvTranspose2d(in_channels,out_channels,
kernel_size = kernel_size,
stride = strides,bias=True)
self.bn1 = nn.BatchNorm2d(out_channels)
def forward(self,x):
out = F.relu(self.bn1(self.conv1(x)))
return out
这里的网络结构可以对照上文中的网络结构图,网络结构还是比较明晰的。
class Unet(nn.Module):
def __init__(self):
super(Unet,self).__init__()
self.layer1_conv = double_conv2d_bn(1,8)
self.layer2_conv = double_conv2d_bn(8,16)
self.layer3_conv = double_conv2d_bn(16,32)
self.layer4_conv = double_conv2d_bn(32,64)
self.layer5_conv = double_conv2d_bn(64,128)
self.layer6_conv = double_conv2d_bn(128,64)
self.layer7_conv = double_conv2d_bn(64,32)
self.layer8_conv = double_conv2d_bn(32,16)
self.layer9_conv = double_conv2d_bn(16,8)
self.layer10_conv = nn.Conv2d(8,1,kernel_size=3,
stride=1,padding=1,bias=True)
self.deconv1 = deconv2d_bn(128,64)
self.deconv2 = deconv2d_bn(64,32)
self.deconv3 = deconv2d_bn(32,16)
self.deconv4 = deconv2d_bn(16,8)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
conv1 = self.layer1_conv(x)
# print("第一层卷积后size{}".format(conv1.size()))
pool1 = F.max_pool2d(conv1,2)
# print("第一层池化后size{}".format(pool1.size()))
conv2 = self.layer2_conv(pool1)
# print("第二层卷积后size{}".format(conv2.size()))
pool2 = F.max_pool2d(conv2,2)
# print("第二层池化后size{}".format(pool2.size()))
conv3 = self.layer3_conv(pool2)
pool3 = F.max_pool2d(conv3,2)
conv4 = self.layer4_conv(pool3)
pool4 = F.max_pool2d(conv4,2)
conv5 = self.layer5_conv(pool4)
convt1 = self.deconv1(conv5)
# print(convt1.size())
# print(conv4.size())
concat1 = torch.cat([convt1,conv4],dim=1)
# print(concat1.size())
conv6 = self.layer6_conv(concat1)
convt2 = self.deconv2(conv6)
concat2 = torch.cat([convt2,conv3],dim=1)
conv7 = self.layer7_conv(concat2)
convt3 = self.deconv3(conv7)
concat3 = torch.cat([convt3,conv2],dim=1)
conv8 = self.layer8_conv(concat3)
convt4 = self.deconv4(conv8)
concat4 = torch.cat([convt4,conv1],dim=1)
conv9 = self.layer9_conv(concat4)
outp = self.layer10_conv(conv9)
# 由于暂时将网络应用于二分类问题,所以此处为sigmoid,如果是多分类,此处改为softmax函数
outp = self.sigmoid(outp)
return outp
网络分割的数据集采用了论文中提到的透射电镜拍摄的30幅果蝇幼虫腹神经索(VNC)的连续切片图像(512×512像素)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5OkoYgKE-1667397841906)(https://cdn.jsdelivr.net/gh/wenruo-shusheng/BlogImageBed@main/img/segementation.png)]
此处保留了分割后图像的灰度特征(不是二值图像;当然也可以是)。
(outp)
return outp
## U-NET分割效果
网络分割的数据集采用了论文中提到的**透射电镜拍摄的30幅果蝇幼虫腹神经索(VNC)的连续切片图像(512×512像素)**。
[外链图片转存中...(img-5OkoYgKE-1667397841906)]
此处保留了分割后图像的灰度特征(不是二值图像;当然也可以是)。