图像分割模型-Unet及代码

Unet网络结构:
图像分割模型-Unet及代码_第1张图片

(1)UNet网络为一个U形结构。采用全卷积神经网络,没有全连接操作

(2)左边encoder编码器部分为特征提取网络:使用conv和pooling进行下采样

(3)右边decoder解码器部分为特征融合网络:右侧上采样产生的特征图与左侧下采样的特征图在channel维度上进行concatenate拼接操作。(图片的维度为:B C H W,分别为batchsize,channel, height, width)

上采样的目的:pooling池化层使得图片宽高减半,会丢失图像信息,降低分辨率。上采样可以提高图片分辨率,并且保留高级抽象特征,然后再与左边低级表层特征高分辨率图片拼接。

上采样方法:使用转置卷积nn.ConvTranspose2d()代替简单的插值上采样方法,既能实现同样的效果,也能加深网络。

(4)最后再经过两次3*3卷积操作,再用1*1的卷积核,输出channel维度为需要分割的类别数num_classes,生成维度为(B,num_classes,H,W)的特征图。

code: 

import torch
import torch.nn as nn
import torch.nn.functional as F


def X2conv(in_channel,out_channel):
    """连续两个3*3卷积"""
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channel),
        nn.ReLU())

class DownsampleLayer(nn.Module):
    """
    下采样层
    """
    def __init__(self,in_channel,out_channel):
        super(DownsampleLayer, self).__init__()
        self.x2conv=X2conv(in_channel,out_channel)
        self.pool=nn.MaxPool2d(kernel_size=2,ceil_mode=True)

    def forward(self,x):
        """
        :param x:上一层pool后的特征
        :return: out_1转入右侧(待拼接),out_1输入到下一层,
        """
        out_1=self.x2conv(x)
        out=self.pool(out_1)
        return out_1,out

class UpSampleLayer(nn.Module):
    """
    上采样层
    """
    def __init__(self,in_channel,out_channel):

        super(UpSampleLayer, self).__init__()
        self.x2conv = X2conv(in_channel, out_channel)
        self.upsample=nn.ConvTranspose2d\ 
(in_channels=out_channel,out_channels=out_channel//2,kernel_size=3,stride=2,padding=1)

    def forward(self,x,out):
        '''
        :param x: decoder中:输入层特征,经过x2conv与上采样upsample,然后拼接
        :param out:左侧encoder层中特征(与右侧上采样层进行cat)
        :return:
        '''
        x=self.x2conv(x)
        x=self.upsample(x)

        # x.shape中H W 应与 out.shape中的H W相同
        if (x.size(2) != out.size(2)) or (x.size(3) != out.size(3)):
            # 将右侧特征H W大小插值变为左侧特征H W大小
            x = F.interpolate(x, size=(out.size(2), out.size(3)),
                            mode="bilinear", align_corners=True)


        # Concatenate(在channel维度)
        cat_out = torch.cat([x, out], dim=1)
        return cat_out

class UNet(nn.Module):
    """
    UNet模型,num_classes为分割类别数
    """
    def __init__(self,num_classes):
        super(UNet, self).__init__()
        #下采样
        self.d1=DownsampleLayer(3,64) #3-64
        self.d2=DownsampleLayer(64,128)#64-128
        self.d3=DownsampleLayer(128,256)#128-256
        self.d4=DownsampleLayer(256,512)#256-512

        #上采样
        self.u1=UpSampleLayer(512,1024)#512-1024-512
        self.u2=UpSampleLayer(1024,512)#1024-512-256
        self.u3=UpSampleLayer(512,256)#512-256-128
        self.u4=UpSampleLayer(256,128)#256-128-64

        #输出:经过一个二层3*3卷积 + 1个1*1卷积
        self.x2conv=X2conv(128,64)
        self.final_conv=nn.Conv2d(64,num_classes,kernel_size=1)  # 最后一个卷积层的输出通道数为分割的类别数
        self._initialize_weights()

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self,x):
        # 下采样层
        out_1,out1=self.d1(x)
        out_2,out2=self.d2(out1)
        out_3,out3=self.d3(out2)
        out_4,out4=self.d4(out3)

        # 上采样层 拼接
        out5=self.u1(out4,out_4)
        out6=self.u2(out5,out_3)
        out7=self.u3(out6,out_2)
        out8=self.u4(out7,out_1)

        # 最后的三层卷积
        out=self.x2conv(out8)
        out=self.final_conv(out)
        return out

if __name__ == "__main__":
    img = torch.randn((2, 3, 360, 480))  # 正态分布初始化

    model = UNet(num_classes=16)

    output = model(img)
    print(output.shape)

你可能感兴趣的:(图像分割模型,pytorch)