unet网络python代码详解_使用pytorch实现论文中的unet网络

设计神经网络的一般步骤:

1. 设计框架

2. 设计骨干网络

Unet网络设计的步骤:

1. 设计Unet网络工厂模式

2. 设计编解码结构

3. 设计卷积模块

4. unet实例模块

Unet网络最重要的特征:

1. 编解码结构。

2. 解码结构比FCN更加完善:通过跳跃连接(skip connection)把编码器各层的特征图与解码器对应层的特征图在通道维度上拼接(concatenate),而不是像FCN那样逐元素相加。

3. 本质是一个框架,编码部分可以替换为多种图像分类网络(例如VGG、ResNet)作为骨干网络。

示例代码:

import torch

import torch.nn as nn

class Unet(nn.Module):
    """Generic U-Net wrapper: encoder -> optional bridge -> decoder.

    Args:
        encoder_blocks: sequence of modules handed to ``Encoder``.
        decoder_blocks: sequence of modules handed to ``Decoder``.
        bridge: optional module applied to the encoder's final output before
            decoding. Defaults to ``None`` (no bridge).

    Note:
        The first two parameters were previously spelled ``Encoder`` and
        ``Decoder``. Those names shadowed the classes of the same name, and
        the body referenced ``encoder_blocks`` / ``decoder_blocks``, which did
        not exist. Positional callers are unaffected by the rename.
    """

    def __init__(self, encoder_blocks, decoder_blocks, bridge=None):
        super(Unet, self).__init__()
        self.encoder = Encoder(encoder_blocks)
        self.decoder = Decoder(decoder_blocks)
        self.bridge = bridge

    def forward(self, x):
        """Run the encoder, the optional bridge and the decoder on ``x``."""
        res = self.encoder(x)
        # The encoder returns a list: its final output first, then the skips.
        out, skips = res[0], res[1:]
        if self.bridge is not None:
            out = self.bridge(out)
        # Decoder.forward takes the skip list first, then the tensor to decode.
        return self.decoder(skips, out)

#设计编码模块

class Encoder(nn.Module):
    """Apply a chain of blocks and collect intermediate outputs as skips.

    Args:
        blocks: non-empty sequence of ``nn.Module`` objects. They are stored
            in an ``nn.ModuleList`` so their parameters are registered with
            this module.
    """

    def __init__(self, blocks):
        super(Encoder, self).__init__()
        # Guard against an empty block list.
        assert len(blocks) > 0, "Encoder needs at least one block"
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        """Return ``[final_output, skip_0, skip_1, ...]``.

        Every block except the last contributes its output as a skip (in
        application order); the last block's output comes first in the list.
        """
        skips = []
        for block in self.blocks[:-1]:
            x = block(x)
            skips.append(x)
        # Index the last block directly so a single-block encoder also works.
        return [self.blocks[-1](x)] + skips

#设计Decoder模块

class Decoder(nn.Module):
    """Decode a tensor through a chain of blocks, merging in skip tensors.

    Args:
        blocks: non-empty sequence of ``nn.Module`` objects. ``blocks[0]`` is
            applied to the input alone; each later block receives the
            channel-wise concatenation of one skip tensor and the running
            tensor.
    """

    def __init__(self, blocks):
        super(Decoder, self).__init__()
        assert len(blocks) > 0, "Decoder needs at least one block"
        self.blocks = nn.ModuleList(blocks)

    def center_crop(self, skips, x):
        """Center-crop two 4-D tensors to their common height and width.

        Args:
            skips: a single skip tensor, indexed as 4-D.
            x: the running decoder tensor, indexed as 4-D.

        Returns:
            ``(cropped_skip, cropped_x)`` whose last two dimensions are equal,
            so the pair can be concatenated along dim 1.
        """
        _, _, height1, width1 = skips.shape
        _, _, height2, width2 = x.shape
        ht, wt = min(height1, height2), min(width1, width2)
        # Offsets are zero for whichever tensor is already the smaller one.
        dh1, dw1 = (height1 - ht) // 2, (width1 - wt) // 2
        dh2, dw2 = (height2 - ht) // 2, (width2 - wt) // 2
        return (
            skips[:, :, dh1:dh1 + ht, dw1:dw1 + wt],
            x[:, :, dh2:dh2 + ht, dw2:dw2 + wt],
        )

    # Backward-compatible alias for the original (misspelled) method name.
    ceter_crop = center_crop

    def forward(self, skips, x, reverse_skips=True):
        """Decode ``x`` using ``skips``.

        Args:
            skips: sequence of skip tensors, exactly ``len(blocks) - 1`` long.
            x: tensor to decode.
            reverse_skips: when true, ``skips`` is consumed last-to-first.

        Returns:
            The output of the last block.
        """
        assert len(skips) == len(self.blocks) - 1, \
            "expected exactly one skip tensor per block after the first"
        if reverse_skips:
            skips = skips[::-1]
        x = self.blocks[0](x)
        for i in range(1, len(self.blocks)):
            # Spatial sizes may differ (e.g. unpadded convolutions), so crop
            # both tensors to a common size before concatenating.
            skip, x = self.center_crop(skips[i - 1], x)
            x = torch.cat([skip, x], dim=1)
            x = self.blocks[i](x)
        return x

#定义了一个卷积block

def unet_convs(in_channels, out_channels, padding=0):
    """Build a block of two 3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        padding: padding passed to both convolutions. With the default ``0``
            each convolution shrinks height and width by 2.

    Returns:
        An ``nn.Sequential`` (which, unlike ``nn.ModuleList``, provides its
        own ``forward``) holding the six layers.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        # The second convolution consumes the first one's out_channels maps.
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )

#实例化Unet模型

def unet(in_channels, out_channels):
    """Instantiate a U-Net with widths 64-128-256-512 and a 1024 bridge.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map (produced by a final 1x1
            convolution; no activation is applied).

    Returns:
        A ``Unet`` built from the block lists below.

    Layout:
        The encoder saves the outputs of its first four blocks as skips
        (64, 128, 256 and 512 channels); its last block only pools. The bridge
        widens 512 -> 1024. Each decoder stage after the first receives the
        matching skip concatenated with the upsampled tensor, which is why its
        convolution input is twice the skip width.
    """

    def pool():
        return nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

    def up(cin, cout):
        # 2x2 transposed convolution with stride 2 doubles height and width.
        return nn.ConvTranspose2d(cin, cout, kernel_size=2, stride=2)

    encoder_blocks = [
        unet_convs(in_channels, 64),
        nn.Sequential(pool(), unet_convs(64, 128)),
        nn.Sequential(pool(), unet_convs(128, 256)),
        nn.Sequential(pool(), unet_convs(256, 512)),
        # Pool-only final block, so the 512-channel map above is kept as a
        # skip and the bridge sees the downsampled tensor.
        pool(),
    ]

    bridge = nn.Sequential(unet_convs(512, 1024))

    decoder_blocks = [
        up(1024, 512),
        nn.Sequential(unet_convs(1024, 512), up(512, 256)),
        nn.Sequential(unet_convs(512, 256), up(256, 128)),
        nn.Sequential(unet_convs(256, 128), up(128, 64)),
        # Output head: fuse the last skip, then map to out_channels.
        nn.Sequential(unet_convs(128, 64), nn.Conv2d(64, out_channels, kernel_size=1)),
    ]

    return Unet(encoder_blocks, decoder_blocks, bridge)

补充知识:Pytorch搭建U-Net网络

U-Net: Convolutional Networks for Biomedical Image Segmentation

import torch.nn as nn

import torch

from torch import autograd

from torchsummary import summary

class DoubleConv(nn.Module):
    """Two consecutive unpadded 3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        in_ch: channels of the input tensor.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        stages = []
        channels = in_ch
        for _ in range(2):
            stages.append(nn.Conv2d(channels, out_ch, 3, padding=0))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
            # The second stage reads what the first one produced.
            channels = out_ch
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        """Apply both stages to ``input`` and return the result."""
        return self.conv(input)

class Unet(nn.Module):
    """U-Net with unpadded convolutions and cropped skip connections.

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the output map.

    The contracting path is four ``DoubleConv`` + 2x2 max-pool stages
    (64, 128, 256, 512 channels) followed by a 1024-channel ``DoubleConv``.
    The expanding path is four transposed-convolution upsamplings, each
    followed by a ``DoubleConv`` over the upsampled tensor concatenated with
    the matching encoder feature map. A 1x1 convolution and a sigmoid produce
    the output.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Contracting path.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Expanding path: transposed convolutions (upsampling would also work).
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _center_crop(feature, target):
        """Center-crop ``feature`` to the height and width of ``target``."""
        _, _, height, width = feature.shape
        _, _, target_h, target_w = target.shape
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return feature[:, :, top:top + target_h, left:left + target_w]

    def forward(self, x):
        """Return the sigmoid-activated output map for ``x``.

        Skip tensors are center-cropped to the size of the upsampled tensor
        they are merged with. For a 572x572 input this yields exactly the
        windows 88:480, 40:240, 16:120 and 4:60 that were previously
        hard-coded, while other valid input sizes now work too.
        """
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))

        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, self._center_crop(c4, up_6)], dim=1)
        c6 = self.conv6(merge6)

        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, self._center_crop(c3, up_7)], dim=1)
        c7 = self.conv7(merge7)

        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, self._center_crop(c2, up_8)], dim=1)
        c8 = self.conv8(merge8)

        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, self._center_crop(c1, up_9)], dim=1)
        c9 = self.conv9(merge9)

        c10 = self.conv10(c9)
        return torch.sigmoid(c10)

if __name__ == "__main__":
    # Smoke test: print the layer summary and the output size for a random
    # single-channel 572x572 image.
    input_shape = (1, 572, 572)
    test_input = torch.rand(1, *input_shape)
    model = Unet(in_ch=1, out_ch=2)
    summary(model, input_shape)
    output = model(test_input)
    print(output.size())

----------------------------------------------------------------

Layer (type) Output Shape Param #

================================================================

Conv2d-1 [-1, 64, 570, 570] 640

BatchNorm2d-2 [-1, 64, 570, 570] 128

ReLU-3 [-1, 64, 570, 570] 0

Conv2d-4 [-1, 64, 568, 568] 36,928

BatchNorm2d-5 [-1, 64, 568, 568] 128

ReLU-6 [-1, 64, 568, 568] 0

DoubleConv-7 [-1, 64, 568, 568] 0

MaxPool2d-8 [-1, 64, 284, 284] 0

Conv2d-9 [-1, 128, 282, 282] 73,856

BatchNorm2d-10 [-1, 128, 282, 282] 256

ReLU-11 [-1, 128, 282, 282] 0

Conv2d-12 [-1, 128, 280, 280] 147,584

BatchNorm2d-13 [-1, 128, 280, 280] 256

ReLU-14 [-1, 128, 280, 280] 0

DoubleConv-15 [-1, 128, 280, 280] 0

MaxPool2d-16 [-1, 128, 140, 140] 0

Conv2d-17 [-1, 256, 138, 138] 295,168

BatchNorm2d-18 [-1, 256, 138, 138] 512

ReLU-19 [-1, 256, 138, 138] 0

Conv2d-20 [-1, 256, 136, 136] 590,080

BatchNorm2d-21 [-1, 256, 136, 136] 512

ReLU-22 [-1, 256, 136, 136] 0

DoubleConv-23 [-1, 256, 136, 136] 0

MaxPool2d-24 [-1, 256, 68, 68] 0

Conv2d-25 [-1, 512, 66, 66] 1,180,160

BatchNorm2d-26 [-1, 512, 66, 66] 1,024

ReLU-27 [-1, 512, 66, 66] 0

Conv2d-28 [-1, 512, 64, 64] 2,359,808

BatchNorm2d-29 [-1, 512, 64, 64] 1,024

ReLU-30 [-1, 512, 64, 64] 0

DoubleConv-31 [-1, 512, 64, 64] 0

MaxPool2d-32 [-1, 512, 32, 32] 0

Conv2d-33 [-1, 1024, 30, 30] 4,719,616

BatchNorm2d-34 [-1, 1024, 30, 30] 2,048

ReLU-35 [-1, 1024, 30, 30] 0

Conv2d-36 [-1, 1024, 28, 28] 9,438,208

BatchNorm2d-37 [-1, 1024, 28, 28] 2,048

ReLU-38 [-1, 1024, 28, 28] 0

DoubleConv-39 [-1, 1024, 28, 28] 0

ConvTranspose2d-40 [-1, 512, 56, 56] 2,097,664

Conv2d-41 [-1, 512, 54, 54] 4,719,104

BatchNorm2d-42 [-1, 512, 54, 54] 1,024

ReLU-43 [-1, 512, 54, 54] 0

Conv2d-44 [-1, 512, 52, 52] 2,359,808

BatchNorm2d-45 [-1, 512, 52, 52] 1,024

ReLU-46 [-1, 512, 52, 52] 0

DoubleConv-47 [-1, 512, 52, 52] 0

ConvTranspose2d-48 [-1, 256, 104, 104] 524,544

Conv2d-49 [-1, 256, 102, 102] 1,179,904

BatchNorm2d-50 [-1, 256, 102, 102] 512

ReLU-51 [-1, 256, 102, 102] 0

Conv2d-52 [-1, 256, 100, 100] 590,080

BatchNorm2d-53 [-1, 256, 100, 100] 512

ReLU-54 [-1, 256, 100, 100] 0

DoubleConv-55 [-1, 256, 100, 100] 0

ConvTranspose2d-56 [-1, 128, 200, 200] 131,200

Conv2d-57 [-1, 128, 198, 198] 295,040

BatchNorm2d-58 [-1, 128, 198, 198] 256

ReLU-59 [-1, 128, 198, 198] 0

Conv2d-60 [-1, 128, 196, 196] 147,584

BatchNorm2d-61 [-1, 128, 196, 196] 256

ReLU-62 [-1, 128, 196, 196] 0

DoubleConv-63 [-1, 128, 196, 196] 0

ConvTranspose2d-64 [-1, 64, 392, 392] 32,832

Conv2d-65 [-1, 64, 390, 390] 73,792

BatchNorm2d-66 [-1, 64, 390, 390] 128

ReLU-67 [-1, 64, 390, 390] 0

Conv2d-68 [-1, 64, 388, 388] 36,928

BatchNorm2d-69 [-1, 64, 388, 388] 128

ReLU-70 [-1, 64, 388, 388] 0

DoubleConv-71 [-1, 64, 388, 388] 0

Conv2d-72 [-1, 2, 388, 388] 130

================================================================

Total params: 31,042,434

Trainable params: 31,042,434

Non-trainable params: 0

----------------------------------------------------------------

Input size (MB): 1.25

Forward/backward pass size (MB): 3280.59

Params size (MB): 118.42

Estimated Total Size (MB): 3400.26

----------------------------------------------------------------

torch.Size([1, 2, 388, 388])

以上这篇使用pytorch实现论文中的unet网络就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

你可能感兴趣的:(unet网络python代码详解_使用pytorch实现论文中的unet网络)