U_net 网络(pytorch学习)

U_net 网络(pytorch学习)

 U_net是一个经典的图像分割网络,可以完成许多功能,在学习U_net网络后结合B站的视频尝试编写U_net代码,锻炼编程能力

一 U_net网络结构

  U_net网络的网络结构如下图所示:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-alsUIo1Z-1662831702017)(“C:\Users\23577\Desktop[2H[T27D8~{THTHD]]VPJD6.png”)]

网络模型代码

  步骤:

  • 先定义下采样网络既两个卷积

    import torch
    import torchvision
    from torch import nn
    
    
    class Double_conv(nn.Module):
        def __init__(self,in_channel, out_channel):
            super(Double_conv, self).__init__()
            """
            在这里使用卷积,保持图像尺寸不变,以便更好计算
            """
            self.layer = nn.Sequential(
                nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1,bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1,bias=False),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(inplace=True),
            )
        def forward(self,x):
            return self.layer(x)
    urn self.layer(x)
    
    
  • 定义U_net整个网络模型

    class U_NET(nn.Module):
        def __init__(self, in_channel,out_channel,features=[64,128,256,512]):
            super(U_NET, self).__init__()
            self.DOWN = nn.ModuleList()
            self.UP = nn.ModuleList()
            self.maxpool = nn.MaxPool2d(2)
            for feature in features:
                self.DOWN.append(Double_conv(in_channel,feature))
                in_channel = feature
    
            for feature in reversed(features):
                self.UP.append(nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2))
                self.UP.append(Double_conv(feature*2,feature))
    
            self.botten = Double_conv(features[-1],features[-1]*2)
            self.final_conv = nn.Conv2d(features[0],out_channel,kernel_size=1,padding=0)
    
        def forward(self,x):
            skip_connect=[]
            for idx in self.DOWN:
                x = idx(x)
                skip_connect.append(x)
                x = self.maxpool(x)
            x = self.botten(x)
            skip_connect = skip_connect[::-1]
    
            for idx in range(0,len(self.UP),2):
                x = self.UP[idx](x)
                """
                为了适用,任意尺寸的图片特征融合时为保证尺寸相同Resize一下
                """
                if x.shape != skip_connect[idx // 2].shape:
                    x = torchvision.transforms.Resize( skip_connect[idx // 2].shape[2:])(x)
                x = torch.cat((x,skip_connect[idx//2]),dim=1)
                x = self.UP[idx+1](x)
            return self.final_conv(x)
    
  • 测试结果

    if __name__ == "__main__":
        x = torch.randn(1,1,161,161)
        model = U_NET(in_channel=1,out_channel=1)
        y = model(x)
        print(y.shape)
        # -> torch.Size([1, 1, 161, 161])
    

你可能感兴趣的:(pytorch)