论文阅读——U-Net: Convolutional Networks for Biomedical Image Segmentation pytorch论文复现

前言:U-Net 是 2015 年发表的一篇关于医学图像分割的论文,论文地址点击这里。U-Net 同样采用编码器-解码器架构:论文中经过四次下采样(max pooling)和四次上采样(转置卷积),形成 U 型结构(注意:下文的复现代码用步长为 2 的卷积代替了最大池化,用最近邻插值加 1×1 卷积代替了转置卷积)。网络结构如下图所示:
论文阅读——U-Net: Convolutional Networks for Biomedical Image Segmentation pytorch论文复现_第1张图片
网上已经有很多关于这篇论文阅读的文章,可以自己搜索一下,也可以参考这篇文章,我感觉是比较详细的,这里我主要放一下复现这篇文章的代码(使用Pytorch框架实现)。

一、数据加载部分

首先准备数据集(数据集链接;下文代码使用的是 PASCAL VOC2012),对数据的预处理见下面的代码:

#  data.py文件用来进行数据集的制作
import os
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
# Shared preprocessing pipeline: turns the PIL images returned by
# keep_image_size_open into tensors.
transform = transforms.Compose([transforms.ToTensor()])

class MyDataset(Dataset):
    """Dataset pairing each segmentation mask with its source photo.

    Masks are read from ``<path>/SegmentationClass`` and the matching photos
    from ``<path>/JPEGImages``; a photo shares its mask's base name but
    carries a ``.jpg`` extension.
    """

    def __init__(self, path):
        """Index the mask files found under *path*.

        Args:
            path: Dataset root containing the ``SegmentationClass`` and
                ``JPEGImages`` folders.
        """
        self.path = path
        # os.listdir returns entries in arbitrary order; sort them so that a
        # given index always refers to the same sample.
        self.name = sorted(os.listdir(os.path.join(path, 'SegmentationClass')))

    def __len__(self):
        """Return the number of mask files, i.e. the number of samples."""
        return len(self.name)

    def _image_path(self, segment_name):
        """Return the path of the photo that belongs to mask *segment_name*."""
        # Swap only the extension: str.replace('png', 'jpg') would also
        # rewrite a "png" that happens to occur inside the base name.
        stem, _ = os.path.splitext(segment_name)
        return os.path.join(self.path, 'JPEGImages', stem + '.jpg')

    def __getitem__(self, index):
        """Return the ``(image, mask)`` tensor pair for sample *index*."""
        segment_name = self.name[index]  # file name of the mask
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        # keep_image_size_open brings differently sized pictures to one common
        # size so that they can be batched.
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open(self._image_path(segment_name))
        return transform(image), transform(segment_image)

if __name__ == '__main__':
    # Smoke test: load the first sample and print the image and mask shapes.
    # The Windows path is a raw string because its backslashes would
    # otherwise be parsed as (invalid) escape sequences.
    data = MyDataset(r'D:\workspace\Pycharm\pytorch-unet-master\pytorch-unet-master\VOCtrainval_11-May-2012\VOCdevkit\VOC2012')
    # Fetch the sample once instead of decoding both files twice.
    image, segment_image = data[0]
    print(image.shape)
    print(segment_image.shape)

utils文件代码如下:

# utils.py文件保存一些工具函数
from PIL import Image
def keep_image_size_open(path, size=(256, 256)):
    """Open an image, pad it to a square and resize it to *size*.

    The image is pasted into the top-left corner of a black RGB square whose
    side equals the longer edge of the image, and that square is then
    resized, so the picture is not stretched along one axis.

    Args:
        path: Path of the image file to open.
        size: ``(width, height)`` of the returned image.

    Returns:
        A new RGB ``PIL.Image.Image`` of the requested size.
    """
    # The context manager closes the underlying file as soon as the pixels
    # have been copied onto the canvas, instead of leaving the handle open.
    with Image.open(path) as img:
        side = max(img.size)
        canvas = Image.new('RGB', (side, side), (0, 0, 0))
        canvas.paste(img, (0, 0))  # paste into the top-left corner
    return canvas.resize(size)

二、Unet网络结构

根据上面论文中的 U-Net 网络结构图来定义网络结构:

#net.py定义网络结构
import torch
from torch import nn
from torch.nn import functional as F

class Conv_Block(nn.Module):
    """Two stacked 3x3 convolution stages that keep the spatial size.

    Each stage is Conv2d (stride 1, reflect padding 1, no bias) ->
    BatchNorm2d -> Dropout2d(0.3) -> LeakyReLU. The first stage maps
    ``in_channel`` to ``out_channel``; the second keeps ``out_channel``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        stages = []
        # Build the two stages in order; only the input width differs.
        for stage_in in (in_channel, out_channel):
            stages += [
                nn.Conv2d(stage_in, out_channel, kernel_size=3, stride=1,
                          padding=1, padding_mode='reflect', bias=False),
                nn.BatchNorm2d(out_channel),
                nn.Dropout2d(0.3),
                nn.LeakyReLU(),
            ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run *x* through both convolution stages."""
        return self.layer(x)


class DownSample(nn.Module):
    """Downsampling step built from a stride-2 3x3 convolution.

    The convolution (reflect padding 1, no bias) keeps the channel count and
    is followed by BatchNorm2d and LeakyReLU.
    """

    def __init__(self, channel):
        super().__init__()
        strided_conv = nn.Conv2d(channel, channel, kernel_size=3, stride=2,
                                 padding=1, padding_mode='reflect', bias=False)
        self.layer = nn.Sequential(
            strided_conv,
            nn.BatchNorm2d(channel),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        """Return the downsampled feature map for *x*."""
        return self.layer(x)


class UpSample(nn.Module):
    """Upsampling step followed by concatenation with a skip connection.

    The input is enlarged with nearest-neighbour interpolation, a 1x1
    convolution halves its channel count, and the result is concatenated
    with the encoder feature map along the channel axis.
    """

    def __init__(self, channel):
        super().__init__()
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)

    def forward(self, x, feature_map):
        """Upsample *x* and concatenate it with *feature_map*.

        Args:
            x: Decoder tensor of shape ``(N, channel, h, w)``.
            feature_map: Encoder tensor of shape ``(N, C, H, W)`` to
                concatenate with.

        Returns:
            Tensor of shape ``(N, channel // 2 + C, H, W)``.
        """
        # Resize to the skip connection's exact spatial size rather than by a
        # fixed factor of 2: a stride-2 downsample of an odd-sized map rounds
        # up, so plain doubling would overshoot and torch.cat would fail.
        # For even sizes this is identical to scale_factor=2.
        up = F.interpolate(x, size=feature_map.shape[2:], mode='nearest')
        out = self.layer(up)
        return torch.cat((out, feature_map), dim=1)


class UNet(nn.Module):
    """U-Net assembled from Conv_Block, DownSample and UpSample.

    The encoder applies five convolution blocks (64, 128, 256, 512 and 1024
    channels) separated by four downsampling steps; the decoder applies four
    upsampling steps, each concatenated with the encoder output of the same
    level and followed by a convolution block. A final 3x3 convolution and a
    sigmoid produce the output.

    Args:
        in_channel: Number of channels of the input image (default 3).
        out_channel: Number of channels of the predicted map (default 3).
    """

    def __init__(self, in_channel=3, out_channel=3):
        super().__init__()
        # Encoder path.
        self.c1 = Conv_Block(in_channel, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        # Decoder path: every UpSample halves the channels and concatenates
        # the skip connection, restoring the count its Conv_Block expects.
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # Output head.
        self.out = nn.Conv2d(64, out_channel, 3, 1, 1)
        # Sigmoid squashes every output value into (0, 1).
        self.Th = nn.Sigmoid()

    def forward(self, x):
        """Compute the forward pass for the image batch *x*."""
        # Encoder: keep every level's output for the skip connections.
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        # Decoder: upsample and merge with the matching encoder level.
        O1 = self.c6(self.u1(R5, R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))

if __name__ == '__main__':
    # Smoke test: push a random batch of two 256x256 RGB images through the
    # network and print the shape of the prediction.
    model = UNet()
    dummy_batch = torch.randn(2, 3, 256, 256)
    prediction = model(dummy_batch)
    print(prediction.shape)

三、进行训练

接下来就要根据网络结构进行训练

# train.py对定义好的网络进行训练
import os
from torch import nn,optim
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
from torchvision.utils import save_image

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
weight_path = 'params/unet.pth'  # file the model weights are saved to / loaded from
data_path = 'data'  # dataset root handed to MyDataset
save_path = 'train_image'  # folder that receives the preview images
if __name__ == '__main__':
    # Create the output folders up front; torch.save and save_image do not
    # create missing directories and would otherwise fail on the first run.
    weight_dir = os.path.dirname(weight_path)
    if weight_dir:
        os.makedirs(weight_dir, exist_ok=True)
    os.makedirs(save_path, exist_ok=True)

    data_loader = DataLoader(MyDataset(data_path), batch_size=4, shuffle=True)
    net = UNet().to(device)
    # Resume from an earlier run when a weight file already exists.
    if os.path.exists(weight_path):
        # map_location lets a checkpoint written on a GPU machine be restored
        # on a CPU-only machine (and vice versa).
        net.load_state_dict(torch.load(weight_path, map_location=device))
        print('successful load weight!')
    else:
        print('not successful load weight')

    # Optimizer and loss function.
    opt = optim.Adam(net.parameters())
    loss_fun = nn.BCELoss()

    # Training loop. NOTE: there is no stopping criterion; it runs until the
    # process is interrupted, relying on the periodic checkpoint below.
    epoch = 1
    while True:
        for i, (image, segment_image) in enumerate(data_loader):
            image, segment_image = image.to(device), segment_image.to(device)

            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image)

            opt.zero_grad()
            train_loss.backward()
            opt.step()

            if i % 5 == 0:
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            if i % 50 == 0:
                torch.save(net.state_dict(), weight_path)

            # Preview of the first sample in the batch: input image, target
            # mask and prediction saved together in one picture.
            img = torch.stack([image[0], segment_image[0], out_image[0]], dim=0)
            save_image(img, f'{save_path}/{i}.png')

        epoch += 1


四、训练结果

当loss收敛到0.06左右时
论文阅读——U-Net: Convolutional Networks for Biomedical Image Segmentation pytorch论文复现_第2张图片
图片分割效果:
论文阅读——U-Net: Convolutional Networks for Biomedical Image Segmentation pytorch论文复现_第3张图片

你可能感兴趣的:(论文阅读,Unet,卷积神经网络)