Pytorch搭建U-Net网络

1、Pytorch

原来常用keras搭建网络模型,后来发现keras的训练模型速度和测试速度都较慢,因此转向使用pytorch,其实两者使用难度差不多,都是高层的深度学习框架,适合研究深度学习。

2、U_Net网络介绍

U_Net网络已经提出很早,常被用在图像语义分割领域。模型的主要结构如下图所示,包括下采样和上采样两个过程。为了保证上采样得到的特征图具有较强的语义信息、提高分割的精准度。会在上采样过程中进行通道拼接再卷积。

Pytorch搭建U-Net网络_第1张图片

3、Pytorch代码

3.1、导入包

导入torch,至于为什么导入numpy,显然是因为喜欢。

import torch
from torch import nn
import numpy as np

3.2、下采样模块

下采样模块本文采用了通用卷积进行搭建,当然在语义分割中用的较多的空洞卷积,以及残差结构对网络性能都是有提升效果的。BN层的作用当然是网络节点输出更加稳定,一定程度上能够缓解梯度爆炸和梯度消失问题。激活函数这里使用了Relu6,同样也是考虑到了数据分布,因为通常图片数据在进入网络模型前会进行标准化处理。

class block_down(nn.Module):
    
    def __init__(self,inp_channel,out_channel):
        super(block_down,self).__init__()
        self.conv1=nn.Conv2d(inp_channel,out_channel,3,padding=1)
        self.conv2=nn.Conv2d(out_channel,out_channel,3,padding=1)
        self.bn=nn.BatchNorm2d(out_channel)
        self.relu=nn.ReLU6(inplace=True)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn(x)
        x=self.relu(x)
        x=self.conv2(x)
        x=self.bn(x)
        x=self.relu(x)
        return x

3.3、上采样模块

上采样模块先使用转置卷积进行~~额-->上采样。

常规卷积的输入和输出尺寸关系是:

out_size=(inp_size-f+2p)/stride +1

转置卷积为:

out_size=(inp_size-1)*stride +f

式中f是卷积核(kernel)的尺寸,stride是卷积核滑动步长。

转置卷积的作用显而易见是~~额-->回到过去。

class block_up(nn.Module):
    
    def __init__(self,inp_channel,out_channel,y):
        super(block_up,self).__init__()
        self.up=nn.ConvTranspose2d(inp_channel,out_channel,2,stride=2)
        self.conv1=nn.Conv2d(inp_channel,out_channel,3,padding=1)
        self.conv2=nn.Conv2d(out_channel,out_channel,3,padding=1)
        self.bn=nn.BatchNorm2d(out_channel)
        self.relu=nn.ReLU6(inplace=True)
        self.y=y

    def forward(self,x):
        x=self.up(x)
        x=torch.cat([x,self.y],dim=1)
        x=self.conv1(x)
        x=self.bn(x)
        x=self.relu(x)
        x=self.conv2(x)
        x=self.bn(x)
        x=self.relu(x)
        return x

3.4、用模块搭建整体网络

至于为什么写block再搭建会显得有点麻烦,原因一是结构清楚,二是超参数更易修改,三是显而易见是我喜欢。

class U_net(nn.Module):
    
    def __init__(self,out_channel):
        super(U_net,self).__init__()
        self.out=nn.Conv2d(64,out_channel,1)
        self.maxpool=nn.MaxPool2d(2)
        
    def forward(self,x):
        block1=block_down(3,64)
        x1_use=block1(x)
        x1=self.maxpool(x1_use)
        block2=block_down(64,128)
        x2_use=block2(x1)
        x2=self.maxpool(x2_use)
        block3=block_down(128,256)
        x3_use=block3(x2)
        x3=self.maxpool(x3_use)
        block4=block_down(256,512)
        x4_use=block4(x3)
        x4=self.maxpool(x4_use)
        block5=block_down(512,1024)
        x5=block5(x4)

        block6=block_up(1024,512,x4_use)
        x6=block6(x5)
        block7=block_up(512,256,x3_use)
        x7=block7(x6)
        block8=block_up(256,128,x2_use)
        x8=block8(x7)
        block9=block_up(128,64,x1_use)
        x9=block9(x8)
        x10=self.out(x9)
        out=nn.Softmax2d()(x10)
        return out 

4、测试

输入形状和输出相同,完成搭建。

input_size: torch.Size([1, 3, 480, 640])
output_size: torch.Size([1, 3, 480, 640])
if __name__=="__main__":
    test_input=torch.rand(1, 3, 480, 640)
    print("input_size:",test_input.size())
    model=U_net(out_channel=3)
    ouput=model(test_input)
    print("output_size:",ouput.size())

 

 

你可能感兴趣的:(语义分割)