自动编码器的训练过程
学习目标:输入特征等于输出特征
训练结果:编码器的输出(压缩特征)是更抽象,更健壮的高级特征。
导入需要的包
import torch
from torch.utils.data import DataLoader,Dataset
from torch import nn
from torchvision import transforms
from PIL import Image
import os
from torchvision.utils import save_image
class Image_data(Dataset):
    """Dataset of (image, label) pairs stored in two sibling directories.

    Input images are listed from ``<path>/<data_path>``. For an input file
    named ``NAME.ext`` the label is read from
    ``<path>/<label_path>/NAME-profile.jpg``, where ``NAME`` is the part of
    the file name before its first dot.

    Args:
        img_h: Target height handed to ``transforms.Resize``.
        img_w: Target width handed to ``transforms.Resize``.
        path: Root directory of the dataset.
        data_path: Sub-directory of ``path`` that holds the input images.
        label_path: Sub-directory of ``path`` that holds the label images.
        process: When true, ``__getitem__`` returns tensors: both items are
            resized and converted with ``ToTensor``, and the image is also
            normalised with mean 0.5 / std 0.5 per channel. When false, the
            opened PIL images are returned as they are.

    Raises:
        ValueError: If ``path``, ``data_path`` or ``label_path`` is missing.
    """

    def __init__(self, img_h=256, img_w=256, path=None, data_path=None,
                 label_path=None, process=True):
        # The three locations follow defaulted parameters, so they need a
        # default too for the signature to be valid Python; None keeps the
        # positional order intact and is rejected explicitly here.
        missing = [
            name
            for name, value in (
                ('path', path),
                ('data_path', data_path),
                ('label_path', label_path),
            )
            if value is None
        ]
        if missing:
            raise ValueError(
                'Image_data requires: ' + ', '.join(missing)
            )
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.data_path = data_path
        self.label_path = label_path
        self.process = process
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.img_path = sorted(os.listdir(os.path.join(path, data_path)))

    def __len__(self):
        """Return the number of entries found in the image directory."""
        return len(self.img_path)

    def __getitem__(self, item):
        """Load the image at index ``item`` together with its label.

        Returns:
            An ``(image, label)`` tuple: tensors when ``self.process`` is
            true, PIL images otherwise.
        """
        img_name = self.img_path[item]
        label_name = img_name.split('.')[0]
        img_file = os.path.join(self.path, self.data_path, img_name)
        label_file = os.path.join(
            self.path, self.label_path, label_name + '-profile.jpg'
        )
        image = Image.open(img_file)
        label = Image.open(label_file)
        if self.process:
            size = [self.img_h, self.img_w]
            transforms_image = transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
            transforms_label = transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
            ])
            # Normalize above is configured for exactly three channels, so
            # grayscale / RGBA / palette files are converted to RGB first.
            image = transforms_image(image.convert('RGB'))
            # NOTE(review): the label keeps whatever mode its file has; a
            # loss expecting a single-channel target needs single-channel
            # label files -- confirm against the data.
            label = transforms_label(label)
        return image, label
class conv_block(nn.Module):
    """Two stacked 3x3 convolution stages, each with batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of
    the input is kept; only the channel count changes from ``ch_in`` to
    ``ch_out``.

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels produced by both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        # First stage maps ch_in -> ch_out, second stage ch_out -> ch_out.
        for width_in in (ch_in, ch_out):
            stages.append(
                nn.Conv2d(width_in, ch_out, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.BatchNorm2d(ch_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x`` and return the result."""
        features = self.conv(x)
        return features
class conv_up(nn.Module):
    """Upsample by a factor of two, then apply a 3x3 conv, batch norm, ReLU.

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels produced by the convolution.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        # Stride 1 / padding 1 keeps the upsampled spatial size unchanged.
        project = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        normalise = nn.BatchNorm2d(ch_out)
        activate = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(enlarge, project, normalise, activate)

    def forward(self, x):
        """Return ``x`` upsampled and projected to ``ch_out`` channels."""
        upsampled = self.conv(x)
        return upsampled
class U_net(nn.Module):
    """Encoder/decoder network with skip connections and a sigmoid output.

    The encoder applies five ``conv_block`` stages (32, 64, 128, 256, 512
    channels) separated by 2x2 max pooling. The decoder upsamples four
    times with ``conv_up``; after each upsampling the result is
    concatenated with the encoder output of the same level and fused by a
    ``conv_block``. A final 1x1 convolution maps to ``ch_out`` channels.

    Because the input is halved four times and doubled four times, its
    height and width have to be divisible by 16 for the concatenations to
    line up.

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels of the output tensor.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder stages.
        self.conv1 = conv_block(ch_in=ch_in, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)

        # Decoder stages: each conv_up halves the channel count, and the
        # following conv_block takes twice that many channels because the
        # matching encoder output is concatenated in front of it.
        self.up5 = conv_up(ch_in=512, ch_out=256)
        self.conv_up5 = conv_block(ch_in=512, ch_out=256)
        self.up4 = conv_up(ch_in=256, ch_out=128)
        self.conv_up4 = conv_block(ch_in=256, ch_out=128)
        self.up3 = conv_up(ch_in=128, ch_out=64)
        self.conv_up3 = conv_block(ch_in=128, ch_out=64)
        self.up2 = conv_up(ch_in=64, ch_out=32)
        self.conv_up2 = conv_block(ch_in=64, ch_out=32)

        # 1x1 projection to the requested number of output channels.
        self.conv1_1 = nn.Conv2d(
            in_channels=32, out_channels=ch_out,
            kernel_size=1, stride=1, padding=0,
        )

    def forward(self, x):
        """Run the network on ``x`` and return sigmoid activations.

        The result has ``ch_out`` channels and the same height and width
        as ``x``.
        """
        # Contracting path: keep every level's output for the skip links.
        enc1 = self.conv1(x)
        enc2 = self.conv2(self.maxpool(enc1))
        enc3 = self.conv3(self.maxpool(enc2))
        enc4 = self.conv4(self.maxpool(enc3))
        bottom = self.conv5(self.maxpool(enc4))

        # Expanding path: upsample, prepend the skip tensor, fuse.
        dec = self.conv_up5(torch.cat((enc4, self.up5(bottom)), dim=1))
        dec = self.conv_up4(torch.cat((enc3, self.up4(dec)), dim=1))
        dec = self.conv_up3(torch.cat((enc2, self.up3(dec)), dim=1))
        dec = self.conv_up2(torch.cat((enc1, self.up2(dec)), dim=1))

        logits = self.conv1_1(dec)
        return torch.sigmoid(logits)
class Trainer(object):
    """Train a ``U_net`` on a dataset with Adam and binary cross-entropy.

    Args:
        ch_in: Number of input channels given to ``U_net``.
        ch_out: Number of output channels given to ``U_net``.
        batch_size: Batch size of the training ``DataLoader``.
        epoch: Number of passes over the dataset.
        lr: Learning rate of the Adam optimiser.
        trainset: Dataset yielding ``(input, target)`` pairs.

    Raises:
        ValueError: If ``trainset`` is ``None``.
    """

    def __init__(self, ch_in=3, ch_out=1, batch_size=16, epoch=50, lr=0.005,
                 trainset=None):
        if trainset is None:
            # Fail here with a clear message instead of deep inside the
            # DataLoader's sampler.
            raise ValueError('Trainer requires a trainset')
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.batch_size = batch_size
        self.epoch = epoch
        self.lr = lr
        self.trainset = trainset
        self.train_dataloader = DataLoader(
            dataset=trainset, batch_size=self.batch_size, shuffle=True
        )
        # Build the model.
        self.up_net = U_net(ch_in=ch_in, ch_out=ch_out)
        # Define the optimiser and the loss function.
        self.optimizer = torch.optim.Adam(
            self.up_net.parameters(), lr=self.lr
        )
        self.loss = nn.BCELoss()

    def train(self):
        """Run the training loop.

        The loss is printed after every batch. After every epoch the model
        weights are written to ``u_net.pkl`` and a preview image built from
        the last batch is written to ``../model2/1.png``.
        """
        for epoch in range(self.epoch):
            bx = bx_gen = None
            for bx, by in self.train_dataloader:
                bx_gen = self.up_net(bx)
                loss = self.loss(bx_gen, by)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # .item() prints the plain number rather than a tensor repr.
                print('epoch', epoch, 'loss', loss.item())
            if bx_gen is None:
                # The loader produced no batch; nothing to save or preview.
                continue
            torch.save(self.up_net.state_dict(), 'u_net.pkl')
            self._save_preview(bx, bx_gen)

    def _save_preview(self, bx, bx_gen, out_path='../model2/1.png'):
        """Save the first input of a batch masked by its prediction."""
        image = bx.detach()[0]
        pred = bx_gen.detach()[0]
        # The mask is 0 where the prediction exceeds 0.5 and 1 elsewhere.
        keep = torch.where(
            pred > 0.5, torch.zeros_like(pred), torch.ones_like(pred)
        )
        # Broadcasting spreads a single-channel mask over all image
        # channels, so no fixed image size or channel count is assumed.
        img = image * keep
        # Every value that is exactly 0 after masking (masked-out pixels,
        # and input values that were already 0) is painted with 255.
        # NOTE(review): if the input was normalised to [-1, 1], the kept
        # pixels are saved in that range -- confirm this is intended.
        img = torch.where(img == 0, torch.full_like(img, 255), img)
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        save_image(img, out_path)
if __name__ == '__main__':
    # The dataset needs to know where its files live; the locations are
    # taken from the command line and passed on by keyword.
    import argparse

    parser = argparse.ArgumentParser(
        description='Train the U_net model on an image/label dataset.'
    )
    parser.add_argument('path', help='root directory of the dataset')
    parser.add_argument(
        'data_path', help='sub-directory of PATH holding the input images'
    )
    parser.add_argument(
        'label_path', help='sub-directory of PATH holding the label images'
    )
    args = parser.parse_args()

    train_data = Image_data(
        path=args.path,
        data_path=args.data_path,
        label_path=args.label_path,
    )
    trainer = Trainer(trainset=train_data)
    trainer.train()