提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
找了两份参考代码,还考虑过使用 H-DenseU-Net 的代码。
import torch
import torch.nn as nn
class conv_block(nn.Module):
    """Classic U-Net double convolution: two 3x3 Conv-BN-ReLU stages.

    Spatial size is preserved (stride=1, padding=1); channels go
    ch_in -> ch_out -> ch_out.
    """

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        # Build the two identical-shaped stages in a loop; the resulting
        # nn.Sequential has the same six numbered children as before.
        stages = []
        for cin, cout in ((ch_in, ch_out), (ch_out, ch_out)):
            stages.append(nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=True))
            stages.append(nn.BatchNorm2d(cout))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the double convolution to a (N, ch_in, H, W) tensor."""
        return self.conv(x)
class up_conv(nn.Module):
    """2x upsample followed by a 3x3 Conv-BN-ReLU.

    Doubles the spatial resolution and maps ch_in -> ch_out channels.
    """

    def __init__(self, ch_in, ch_out):
        super(up_conv, self).__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.up = nn.Sequential(upsample, conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        """Upsample x by 2x and convolve; returns (N, ch_out, 2H, 2W)."""
        return self.up(x)
class Conv_Block(nn.Module):
    """Pre-activation unit (DenseNet-style layer): BN -> ReLU -> 3x3 conv."""

    def __init__(self, ch_in, ch_out):
        super(Conv_Block, self).__init__()
        layers = [
            nn.BatchNorm2d(ch_in),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply BN-ReLU-conv; spatial size is preserved."""
        return self.conv(x)
class dens_block(nn.Module):
    """Three densely connected Conv_Block layers.

    Each layer receives the concatenation of the block input and all
    previous layer outputs, so the input channel counts grow by ch_out
    per layer; the block returns only the last layer's ch_out channels.
    """

    def __init__(self, ch_in, ch_out):
        super(dens_block, self).__init__()
        self.conv1 = Conv_Block(ch_in, ch_out)
        self.conv2 = Conv_Block(ch_in + ch_out, ch_out)
        self.conv3 = Conv_Block(ch_in + 2 * ch_out, ch_out)

    def forward(self, input_tensor):
        out1 = self.conv1(input_tensor)
        out2 = self.conv2(torch.cat([out1, input_tensor], dim=1))
        out3 = self.conv3(torch.cat([out1, input_tensor, out2], dim=1))
        return out3
class Conv2D(nn.Module):
    """Single 3x3 Conv-BN-ReLU unit (post-activation ordering)."""

    def __init__(self, ch_in, ch_out):
        super(Conv2D, self).__init__()
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv = nn.Sequential(conv, nn.BatchNorm2d(ch_out), nn.ReLU(inplace=True))

    def forward(self, x):
        """Map (N, ch_in, H, W) to (N, ch_out, H, W)."""
        return self.conv(x)
class DenseU_Net(nn.Module):
    """U-Net whose encoder and decoder stages are dense blocks.

    Args:
        img_ch: number of input image channels (default 3).
        output_ch: number of output (segmentation) channels (default 1).

    The input's spatial size must be divisible by 16 (four 2x max-pools).
    Fix over the original: ``torch.cat`` was stored as module attributes
    (``self.add6`` etc.) and called through them — it is a free function,
    not a layer, so it is now called directly in forward(); leftover
    commented-out debug prints were removed.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super(DenseU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stem: 7x7 conv keeps spatial size (padding=3).
        self.Conv0 = nn.Conv2d(img_ch, 32, kernel_size=7, padding=3, stride=1)
        # Encoder.
        self.Conv1 = dens_block(ch_in=32, ch_out=64)
        self.Conv2 = dens_block(ch_in=64, ch_out=64)
        self.Conv3 = dens_block(ch_in=64, ch_out=128)
        self.Conv4 = conv_block(ch_in=128, ch_out=256)
        # Bottleneck ("center") with dropout for regularisation.
        self.Conv5_1 = Conv2D(ch_in=256, ch_out=512)
        self.Conv5_2 = Conv2D(ch_in=512, ch_out=512)
        self.Drop5 = nn.Dropout(0.5)
        # Decoder: each stage upsamples, concatenates the matching encoder
        # skip connection, then runs a dense block; the dens_block input
        # channel count is decoder + skip channels.
        self.Up6 = up_conv(512, 512)
        self.up6 = dens_block(512 + 256, 256)
        self.Up7 = up_conv(256, 256)
        self.up7 = dens_block(256 + 128, 128)
        self.Up8 = up_conv(128, 128)
        self.up8 = dens_block(128 + 64, 64)
        self.Up9 = up_conv(64, 64)
        self.up9 = dens_block(64 + 64, 64)
        # Head: 7x7 conv back to 32 channels, ReLU, 3x3 conv to output_ch.
        self.conv10_1 = nn.Conv2d(64, 32, 7, 1, 3)
        self.relu = nn.ReLU(inplace=True)
        self.conv10_2 = nn.Conv2d(32, output_ch, 3, 1, 1)

    def forward(self, x):
        # Spatial sizes in the trailing comments assume a 256x256 input.
        x = self.Conv0(x)             # 256
        down1 = self.Conv1(x)         # 256
        pool1 = self.Maxpool(down1)   # 128
        down2 = self.Conv2(pool1)     # 128
        pool2 = self.Maxpool(down2)   # 64
        down3 = self.Conv3(pool2)     # 64
        pool3 = self.Maxpool(down3)   # 32
        down4 = self.Conv4(pool3)     # 32
        pool4 = self.Maxpool(down4)   # 16

        conv5 = self.Conv5_1(pool4)   # 16
        conv5 = self.Conv5_2(conv5)   # 16
        drop5 = self.Drop5(conv5)     # 16

        up6 = self.Up6(drop5)         # 32
        up6 = self.up6(torch.cat([down4, up6], dim=1))
        up7 = self.Up7(up6)           # 64
        up7 = self.up7(torch.cat([down3, up7], dim=1))
        up8 = self.Up8(up7)           # 128
        up8 = self.up8(torch.cat([down2, up8], dim=1))
        up9 = self.Up9(up8)           # 256
        up9 = self.up9(torch.cat([down1, up9], dim=1))

        conv10 = self.conv10_1(up9)
        conv10 = self.relu(conv10)
        return self.conv10_2(conv10)
# Smoke test: build DenseU_Net and push one random 256x256 RGB image through it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DenseU_Net(img_ch=3, output_ch=1)
model = model.to(device)
print(model)
input_1 = torch.rand(1, 3, 256, 256)
input_1 = input_1.to(device)
print(input_1.shape)
output = model(input_1)
print(output.shape)
我主要在 DenseU-Net 代码方面纠结捣鼓了 3 天(太菜了┭┮﹏┭┮)。第一天很简单地从网上找代码,找到一个用 PyTorch 实现的 DenseU-Net,写得很复杂,我没怎么看懂:它的密集连接不是封装好的,用了很多编程技巧;把数据集放上去运行之后,发现它的准确率很低。之后我就想着和 ResU-Net 一样调用 PyTorch 现有的网络。考虑参数量我选择了 DenseNet121,发现它的结构是 3-64-256-512-1024;如果按对称设计,中间层通道数为 2048,解码器为 1024-512-256-64-1。参数量很大(约 1 亿),在尝试运行之后发现过拟合了:训练集上 99% 准确率,验证集准确率在 97% 左右徘徊,没有上升趋势。然后考虑把网络调小一点,可能这个时候把项目弄坏了,准确率掉到 20% 左右,不管是测试集还是验证集。考虑是不是模型调得太小,又重复实验几次,准确率还是在 20% 左右,于是放弃这个方案。之后又转向 H-Dense U-Net,还是看不懂,Keras 转 PyTorch 失败。
之后看到一篇论文把 DenseU-Net 用到视频分割(大概与 SLAM 有关),但它的框架是 Keras,我没有学过这个框架,就把它的代码改写成 PyTorch,运行后准确率只有 20%。这时突然想到:也许不是模型架构的问题,而是项目本身出了问题。于是开始实验,发现 U-Net 的准确率也只有 20%,便更换项目,重新实现 Dense-UNet。
根据 DenseNet121 模型和《【医学图像分割网络】之 Res U-Net 网络 PyTorch 复现》编写的模型,就是太大了,会过拟合,而且我的电脑性能不太行,运行不起来。
"""
Dense121 + U-Net
"""
import torch
from torch import nn
import torchvision.models as models
import torch.nn.functional as F
from torchsummary import summary
class expansive_block(nn.Module):
    """Decoder stage: bilinear 2x upsample, optional skip concat, double conv.

    The double conv is Conv-ReLU-BN twice, mapping
    in_channels -> mid_channels -> out_channels.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(mid_channels),
            nn.Conv2d(mid_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, d, e=None):
        """Upsample d by 2x; when a skip tensor e is given, concat [e, d] first.

        in_channels must equal d's channels (plus e's channels when e is given).
        """
        upsampled = F.interpolate(d, scale_factor=2, mode='bilinear', align_corners=True)
        if e is None:
            return self.block(upsampled)
        return self.block(torch.cat([e, upsampled], dim=1))
def final_block(in_channels, out_channels):
    """Build the output head: 3x3 conv -> ReLU -> BatchNorm, same spatial size."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    )
class DenseUnet121_Unet(nn.Module):
    """U-Net with a DenseNet-121 encoder and a plain conv decoder.

    Encoder skip channel counts (as recorded by the original author for a
    256x256 input): denseblock1 -> 256, denseblock2 -> 512,
    denseblock3 -> 1024, denseblock4 -> 1024; bottleneck -> 2048.

    Args:
        in_channel: nominal input channels. NOTE(review): not actually used —
            the DenseNet stem is fixed at 3 input channels; confirm callers
            only ever pass 3.
        out_channel: number of output (segmentation) channels.
        pretrained: load ImageNet weights for the encoder.

    Fix over the original: the eleven debug ``print`` statements that ran on
    every forward pass were removed.
    """

    def __init__(self, in_channel, out_channel, pretrained=False):
        super(DenseUnet121_Unet, self).__init__()
        self.densenet = models.densenet121(pretrained=pretrained)
        # Stem: DenseNet's first conv + norm (no pooling here).
        self.layer0 = nn.Sequential(
            self.densenet.features.conv0,
            self.densenet.features.norm0
        )
        # Encode: reuse the four dense blocks and three transitions.
        self.denseblock1 = self.densenet.features.denseblock1
        self.denseblock2 = self.densenet.features.denseblock2
        self.denseblock3 = self.densenet.features.denseblock3
        self.denseblock4 = self.densenet.features.denseblock4
        self.transition1 = self.densenet.features.transition1
        self.transition2 = self.densenet.features.transition2
        self.transition3 = self.densenet.features.transition3
        # NOTE(review): unused in forward(); kept so external references to
        # model.pool keep working.
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # Bottleneck: double Conv-ReLU-BN to 2048 channels, then halve the
        # spatial size with max-pooling.
        self.bottleneck = torch.nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
            nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2048),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        )
        # Decode: each expansive_block's in_channels is skip + decoder channels.
        self.conv_decode4 = expansive_block(1024 + 2048, 1024, 1024)
        self.conv_decode3 = expansive_block(1024 + 1024, 512, 512)
        self.conv_decode2 = expansive_block(512 + 512, 256, 256)
        self.conv_decode1 = expansive_block(256 + 256, 128, 128)
        self.conv_decode0 = expansive_block(128, 64, 64)  # no skip at full res
        self.final_layer = final_block(64, out_channel)

    def forward(self, x):
        x = self.layer0(x)
        # Encode
        encode_block1 = self.denseblock1(x)
        encode_block2 = self.denseblock2(self.transition1(encode_block1))
        encode_block3 = self.denseblock3(self.transition2(encode_block2))
        encode_block4 = self.denseblock4(self.transition3(encode_block3))
        # Bottleneck
        bottleneck = self.bottleneck(encode_block4)
        # Decode (deepest skip first)
        decode_block4 = self.conv_decode4(bottleneck, encode_block4)
        decode_block3 = self.conv_decode3(decode_block4, encode_block3)
        decode_block2 = self.conv_decode2(decode_block3, encode_block2)
        decode_block1 = self.conv_decode1(decode_block2, encode_block1)
        decode_block0 = self.conv_decode0(decode_block1)
        return self.final_layer(decode_block0)
# Set flag to a truthy value to run a quick 224x224 smoke test of the network.
flag = 0
if flag:
    image = torch.rand(1, 3, 224, 224)
    # Fix: the original rebound the name DenseUnet121_Unet to an instance,
    # shadowing the class and breaking any later DenseUnet121_Unet(...) call.
    net = DenseUnet121_Unet(in_channel=3, out_channel=1)
    mask = net(image)
    print(mask.shape)
# Smoke-test the network: build it and push one random 256x256 image through.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = DenseUnet121_Unet(in_channel=3, out_channel=1, pretrained=False).to(device)
print(model)
input_1 = torch.rand(1, 3, 256, 256).to(device)
print(input_1.shape)
output = model(input_1)
print(output.shape)