图像分割应用 Unet

图像分割应用Unet

  • 自动编码器复习
  • 图像分割应用
  • 加载数据
  • 模型构建
  • 训练

自动编码器复习

自动编码器的训练过程
学习目标：让输出特征尽可能等于输入特征（即重建输入）。
训练结果：编码器的输出（压缩特征）是更抽象、更健壮的高级特征。

图像分割应用

导入需要的包

import torch
from torch.utils.data import DataLoader,Dataset
from torch import nn
from torchvision import transforms
from PIL import Image
import os
from torchvision.utils import save_image

加载数据

class Image_data(Dataset):
    """Paired image / mask dataset for binary segmentation.

    Each input image ``<path>/<data_path>/<name>.<ext>`` is paired with the
    mask ``<path>/<label_path>/<name>-profile.jpg``.

    Args:
        img_h, img_w: size images and masks are resized to when ``process``
            is true.
        path: dataset root directory. The default is a placeholder -- point
            it at your dataset.
        data_path: sub-directory of ``path`` holding the input images
            (placeholder default).
        label_path: sub-directory of ``path`` holding the ``-profile.jpg``
            masks (placeholder default).
        process: if true, ``__getitem__`` returns tensors (image normalized
            with mean=std=0.5, mask as a 1-channel tensor in [0, 1]);
            otherwise it returns the RGB image and grayscale mask as PIL
            images.
    """

    def __init__(self, img_h=256, img_w=256, path='.', data_path='images',
                 label_path='profiles', process=True):
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.data_path = data_path
        self.label_path = label_path
        self.process = process
        # Sorted so that item indices are reproducible (os.listdir order is
        # arbitrary).
        self.img_path = sorted(os.listdir(os.path.join(self.path, self.data_path)))
        # Built once here instead of on every __getitem__ call.
        self.transforms_image = transforms.Compose([
            transforms.Resize([self.img_h, self.img_w]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        self.transforms_label = transforms.Compose([
            transforms.Resize([self.img_h, self.img_w]),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.img_path)

    @staticmethod
    def _label_file(img_name):
        """Return the mask file name for ``img_name`` (``a.b.jpg`` -> ``a.b-profile.jpg``)."""
        # splitext only strips the final extension, unlike split('.')[0].
        return os.path.splitext(img_name)[0] + '-profile.jpg'

    def __getitem__(self, item):
        img_name = self.img_path[item]
        image_file = os.path.join(self.path, self.data_path, img_name)
        label_file = os.path.join(self.path, self.label_path, self._label_file(img_name))
        # convert() loads the pixels, so the file can be closed right away.
        # RGB matches the 3-channel Normalize; 'L' gives the single-channel
        # target expected with ch_out=1.
        with Image.open(image_file) as im:
            image = im.convert('RGB')
        with Image.open(label_file) as lb:
            label = lb.convert('L')
        if self.process:
            image = self.transforms_image(image)
            label = self.transforms_label(label)
        return image, label
		

模型构建

class conv_block(nn.Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Stride 1 with padding 1 keeps the spatial size unchanged.

    Args:
        ch_in: channels of the incoming feature map.
        ch_out: channels produced by both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        # First stage maps ch_in -> ch_out, second stage ch_out -> ch_out.
        for stage_in in (ch_in, ch_out):
            stages.extend([
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class conv_up(nn.Module):
    """Upsample by 2, then one 3x3 Conv2d -> BatchNorm2d -> ReLU stage.

    ``nn.Upsample`` is used with its default (nearest-neighbour) mode.

    Args:
        ch_in: channels of the incoming feature map.
        ch_out: channels produced by the convolution.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        return self.conv(x)
class U_net(nn.Module):
    """U-Net with a 4-level encoder/decoder (32-64-128-256-512 channels).

    Args:
        ch_in: channels of the input image.
        ch_out: channels of the output map.

    ``forward`` takes an ``(N, ch_in, H, W)`` batch and returns sigmoid
    probabilities of shape ``(N, ch_out, H, W)``. H and W must be at least 16
    (four 2x max-pools). They no longer need to be divisible by 16: when
    pooling floors an odd size, the upsampled decoder map is zero-padded to
    match its skip connection before concatenation.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Construction order and attribute names are kept so existing
        # checkpoints (state_dict keys) still load.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = conv_block(ch_in=ch_in, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)

        self.up5 = conv_up(ch_in=512, ch_out=256)
        self.conv_up5 = conv_block(ch_in=512, ch_out=256)

        self.up4 = conv_up(ch_in=256, ch_out=128)
        self.conv_up4 = conv_block(ch_in=256, ch_out=128)

        self.up3 = conv_up(ch_in=128, ch_out=64)
        self.conv_up3 = conv_block(ch_in=128, ch_out=64)

        self.up2 = conv_up(ch_in=64, ch_out=32)
        self.conv_up2 = conv_block(ch_in=64, ch_out=32)

        self.conv1_1 = nn.Conv2d(in_channels=32, out_channels=ch_out, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _match_size(d, skip):
        """Zero-pad ``d`` on the right/bottom (extra pixel) so its H, W equal ``skip``'s."""
        diff_h = skip.size(2) - d.size(2)
        diff_w = skip.size(3) - d.size(3)
        if diff_h == 0 and diff_w == 0:
            return d
        # Pad order for the last two dims is [left, right, top, bottom].
        return nn.functional.pad(
            d, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])

    def _decode(self, up, conv, d, skip):
        """Upsample ``d``, concatenate it after ``skip`` on channels, and fuse with ``conv``."""
        d = self._match_size(up(d), skip)
        return conv(torch.cat((skip, d), dim=1))

    def forward(self, x):
        # Encoder: each level halves H and W (floor) and widens the channels.
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # Decoder with skip connections from the matching encoder level.
        d5 = self._decode(self.up5, self.conv_up5, x5, x4)
        d4 = self._decode(self.up4, self.conv_up4, d5, x3)
        d3 = self._decode(self.up3, self.conv_up3, d4, x2)
        d2 = self._decode(self.up2, self.conv_up2, d3, x1)

        return torch.sigmoid(self.conv1_1(d2))

训练

class Trainer(object):
    """Train a ``U_net`` for segmentation with Adam and ``nn.BCELoss``.

    Args:
        ch_in: input image channels.
        ch_out: output mask channels.
        batch_size: mini-batch size of the shuffled DataLoader.
        epoch: number of passes over ``trainset``.
        lr: Adam learning rate.
        trainset: dataset yielding ``(image, mask)`` pairs. Masks must lie in
            [0, 1] because the loss is ``nn.BCELoss``.
        model_path: file the model ``state_dict`` is saved to after every
            epoch.
        preview_path: file a preview of the last batch's first sample is
            written to after every epoch. Its directory is created if needed.
        device: torch device to train on. When ``None``, CUDA is used if
            available, otherwise CPU.

    Raises:
        ValueError: if ``trainset`` is not given.
    """

    def __init__(self, ch_in=3, ch_out=1, batch_size=16, epoch=50, lr=0.005,
                 trainset=None, model_path='u_net.pkl',
                 preview_path='../model2/1.png', device=None):
        if trainset is None:
            raise ValueError('trainset is required')
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.batch_size = batch_size
        self.epoch = epoch
        self.lr = lr
        self.trainset = trainset
        self.model_path = model_path
        self.preview_path = preview_path
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.train_dataloader = DataLoader(dataset=trainset, batch_size=self.batch_size, shuffle=True)
        # Load the model.
        self.up_net = U_net(ch_in=ch_in, ch_out=ch_out).to(self.device)
        # Define the optimizer and the loss. BCELoss expects probabilities,
        # which U_net produces through its final sigmoid.
        self.optimizer = torch.optim.Adam(self.up_net.parameters(), lr=self.lr)
        self.loss = nn.BCELoss()

    def train(self):
        """Run ``self.epoch`` epochs; checkpoint and write a preview after each epoch."""
        for epoch in range(self.epoch):
            self.up_net.train()
            last = None
            for bx, by in self.train_dataloader:
                bx = bx.to(self.device)
                by = by.to(self.device)
                bx_gen = self.up_net(bx)
                loss = self.loss(bx_gen, by)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                print('epoch', epoch, 'loss', loss.item())
                last = (bx, bx_gen)
            # Checkpoint once per epoch rather than on every optimizer step.
            torch.save(self.up_net.state_dict(), self.model_path)
            if last is not None:
                self._save_preview(*last)

    def _save_preview(self, bx, bx_gen):
        """Save the first sample of a batch, painting pixels predicted > 0.5 white."""
        with torch.no_grad():
            # Undo Normalize(mean=0.5, std=0.5) so values are back in [0, 1],
            # the range save_image expects. This assumes the inputs were
            # normalized that way (as Image_data does) -- confirm for other
            # datasets.
            image = (bx[0].detach().cpu() * 0.5 + 0.5).clamp(0, 1)
            prob = bx_gen[0].detach().cpu()
            # Pixels predicted > 0.5 become white; the rest keep the input
            # image. expand_as broadcasts the mask over the image channels
            # instead of assuming a 3x256x256 shape.
            blank = (prob > 0.5).expand_as(image)
            img = torch.where(blank, torch.ones_like(image), image)
        out_dir = os.path.dirname(self.preview_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        save_image(img, self.preview_path)
				
if __name__ == '__main__':
    # Train with Trainer's default settings on the default Image_data dataset.
    Trainer(trainset=Image_data()).train()
				
		

你可能感兴趣的:(强化与提高,计算机视觉,人工智能)