unet详解_UNet解释及Python实现

介绍

在图像分割中,机器必须将图像分割成不同的segments,每个segment代表不同的实体。

图像分割示例

正如你在上面看到的,图像如何变成两个部分,一个代表猫,另一个代表背景。图像分割在从自动驾驶汽车到卫星的许多领域都很有用。也许其中最重要的是医学影像。医学图像的微妙之处是相当复杂的。一台能够理解这些细微差别并识别出必要区域的机器,可以对医疗保健产生深远的影响。

卷积神经网络在简单的图像分割问题上取得了不错的效果,但在复杂的图像分割问题上却没有取得任何进展。这就是UNet的作用。UNet最初是专门为医学图像分割而设计的。该方法取得了良好的效果,并在以后的许多领域得到了应用。在本文中,我们将讨论UNet工作的原因和方式

UNet背后的直觉

卷积神经网络(CNN)背后的主要思想是学习图像的特征映射,并利用它进行更细致的特征映射。这在分类问题中很有效,因为图像被转换成一个向量,这个向量用于进一步的分类。但是在图像分割中,我们不仅需要将feature map转换成一个向量,还需要从这个向量重建图像。这是一项巨大的任务,因为要将向量转换成图像比反过来更困难。UNet的整个理念都围绕着这个问题。

在将图像转换为向量的过程中,我们已经学习了图像的特征映射,为什么不使用相同的映射将其再次转换为图像呢?这就是UNet背后的秘诀。用同样的 feature maps,将其用于contraction 来将矢量扩展成segmented image。这将保持图像的结构完整性,这将极大地减少失真。让我们更简单地理解架构。

UNet架构

UNet架构

该架构看起来像一个'U'。该体系结构由三部分组成:contraction,bottleneck和expansion 部分。contraction部分由许多contraction块组成。每个块接受一个输入,应用两个3X3的卷积层,然后是一个2X2的最大池化。在每个块之后,核或特征映射的数量会加倍,这样体系结构就可以有效地学习复杂的结构。最底层介于contraction层和expansion 层之间。它使用两个3X3 CNN层,然后是2X2 up convolution层。

这种架构的核心在于expansion 部分。与contraction层类似,它也包含几个expansion 块。每个块将输入传递到两个3X3 CNN层,然后是2X2上采样层。此外,卷积层使用的每个块的feature map数量得到一半,以保持对称性。每次输入也被相应的收缩层的 feature maps所附加。这个动作将确保在contracting 图像时学习到的特征将被用于重建图像。expansion 块的数量与contraction块的数量相同。之后,生成的映射通过另一个3X3 CNN层,feature map的数量等于所需的segment的数量。

UNet中的损失计算

UNet对每个像素使用了一种新颖的损失加权方案,使得分割对象的边缘具有更高的权重。这种损失加权方案帮助U-Net模型以不连续的方式分割生物医学图像中的细胞,以便在binary segmentation map中容易识别单个细胞。

首先,在所得图像上应用pixel-wise softmax,然后是交叉熵损失函数。所以我们将每个像素分类为一个类。我们的想法是,即使在分割中,每个像素都必须存在于某个类别中,我们只需要确保它们可以。因此,我们只是将分段问题转换为多类分类问题,与传统的损失函数相比,它表现得非常好。

UNet实现的Python代码

Python代码如下:

import torchfrom torch import nnimport torch.nn.functional as Fimport torch.optim as optimclass UNet(nn.Module): def contracting_block(self, in_channels, out_channels, kernel_size=3): block = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels), torch.nn.ReLU(), torch.nn.BatchNorm2d(out_channels), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels), torch.nn.ReLU(), torch.nn.BatchNorm2d(out_channels), ) return block def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3): block = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) ) return block def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3): block = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel), torch.nn.ReLU(), torch.nn.BatchNorm2d(mid_channel), torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1), torch.nn.ReLU(), torch.nn.BatchNorm2d(out_channels), ) return block def __init__(self, in_channel, out_channel): super(UNet, self).__init__() #Encode self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64) self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2) self.conv_encode2 = self.contracting_block(64, 128) self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2) self.conv_encode3 = self.contracting_block(128, 256) self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2) # Bottleneck self.bottleneck = torch.nn.Sequential( torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512), torch.nn.ReLU(), torch.nn.BatchNorm2d(512), torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512), torch.nn.ReLU(), torch.nn.BatchNorm2d(512), torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1) ) # Decode self.conv_decode3 = self.expansive_block(512, 256, 128) self.conv_decode2 = self.expansive_block(256, 128, 64) self.final_layer = self.final_block(128, 64, out_channel) def crop_and_concat(self, upsampled, bypass, crop=False): if crop: c = (bypass.size()[2] - upsampled.size()[2]) // 2 bypass = F.pad(bypass, (-c, -c, -c, -c)) return torch.cat((upsampled, bypass), 1) def forward(self, x): # Encode encode_block1 = self.conv_encode1(x) encode_pool1 = self.conv_maxpool1(encode_block1) encode_block2 = self.conv_encode2(encode_pool1) encode_pool2 = self.conv_maxpool2(encode_block2) encode_block3 = self.conv_encode3(encode_pool2) encode_pool3 = self.conv_maxpool3(encode_block3) # Bottleneck bottleneck1 = self.bottleneck(encode_pool3) # Decode decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True) cat_layer2 = self.conv_decode3(decode_block3) decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True) cat_layer1 = self.conv_decode2(decode_block2) decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True) final_layer = self.final_layer(decode_block1) return final_layer

以上Python代码中的UNet模块代表了UNet的整体架构。使用contracaction_block和expansive_block分别创建contraction部分和expansion部分。crop_and_concat函数的作用是将contraction层的输出添加到新的expansion层输入中。训练部分的Python代码可以写成

unet = Unet(in_channel=1,out_channel=2)#out_channel represents number of segments desiredcriterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)optimizer.zero_grad() outputs = unet(inputs)# permute such that number of desired segments would be on 4th dimensionoutputs = outputs.permute(0, 2, 3, 1)m = outputs.shape[0]# Resizing the outputs and label to caculate pixel wise softmax lossoutputs = outputs.resize(m*width_out*height_out, 2)labels = labels.resize(m*width_out*height_out)loss = criterion(outputs, labels)loss.backward()optimizer.step()

结论

图像分割是一个重要的问题,每天都有一些新的研究论文发表。UNet在这类研究中做出了重大贡献。许多新架构的灵感都来自UNet。在业界,这种体系结构有很多变体,因此有必要理解第一个变体,以便更好地理解它们。

本文仅代表作者个人观点,不代表SEO研究协会网官方发声,对观点有疑义请先联系作者本人进行修改,若内容非法请联系平台管理员,邮箱[email protected]。更多相关资讯,请到SEO研究协会网www.seoxiehui.cn学习互联网营销技术请到巨推学院www.jutuiedu.com。

你可能感兴趣的:(unet详解)