核心目标:训练网络,使其输出尽可能重建(等于)输入
用途:降维、特征提取、初始化深度网络
训练方式:梯度下降+反向传播
极小极大博弈(零和博弈)
游戏双方分别是生成器和判别器。生成器学习伪造数据,判别器学习判断数据的真实性。
为了胜利双方不断自我优化,各自提高生成能力和判别能力,最终以假乱真。
真实数据集中采样数据,并标记为1。
从生成器采样(生成)数据,并标记为0。
锁定生成器不训练,反向传播训练判别器。
锁住判别器,不训练判别器,但提供反向传播梯度。
从生成器采样数据,进行反向传播。
反向传播训练生成器:生成器的优化目标与判别器相反(希望生成的数据被判别器判为真);此时不训练判别器,判别器仅仅为生成器提供梯度。
class Image_data(Dataset):
    """Paired image/label dataset read from two sibling directories.

    Every file found in ``<path>/<data_path>`` is treated as an input image.
    Its label is the file with the same stem and a ``.jpg`` extension inside
    ``<path>/<label_path>``.

    Args:
        img_h: Target height handed to ``transforms.Resize``.
        img_w: Target width handed to ``transforms.Resize``.
        path: Root directory that contains both sub-directories.
        data_path: Name of the input-image sub-directory under ``path``.
        label_path: Name of the label-image sub-directory under ``path``.
        process: When truthy, ``__getitem__`` returns tensors (both resized,
            the input additionally normalised with mean/std 0.5 per channel);
            otherwise the opened PIL images are returned unchanged.

    Raises:
        ValueError: If ``path``, ``data_path`` or ``label_path`` is missing.
    """

    def __init__(self, img_h=256, img_w=256, path=None, data_path=None,
                 label_path=None, process=True):
        # The directory arguments follow defaulted parameters, so they need
        # defaults themselves to be valid Python; they are still mandatory.
        if path is None or data_path is None or label_path is None:
            raise ValueError('path, data_path and label_path are required')
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.data_path = data_path
        self.label_path = label_path
        self.process = process
        # os.listdir returns entries in arbitrary order; sort so that an
        # index always maps to the same sample.
        self.img_data = sorted(os.listdir(os.path.join(self.path, self.data_path)))
        # Transform pipelines are built once on first use, not per item.
        self._pipelines = None

    def _get_pipelines(self):
        """Return the cached (image, label) transform pipelines."""
        if self._pipelines is None:
            size = [self.img_h, self.img_w]
            image_pipeline = transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
            label_pipeline = transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
            ])
            self._pipelines = (image_pipeline, label_pipeline)
        return self._pipelines

    def __len__(self):
        """Return the number of files in the input-image directory."""
        return len(self.img_data)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at index ``item``."""
        image_name = self.img_data[item]
        # splitext strips only the final extension, so names that contain
        # extra dots keep their full stem.
        label_name = os.path.splitext(image_name)[0]
        image_path = os.path.join(self.path, self.data_path, image_name)
        label_path = os.path.join(self.path, self.label_path, label_name + '.jpg')
        image = Image.open(image_path)
        label = Image.open(label_path)
        if self.process:
            image_pipeline, label_pipeline = self._get_pipelines()
            # Normalize is configured for three channels, so force RGB.
            image = image_pipeline(image.convert('RGB'))
            label = label_pipeline(label)
        return image, label
class conv_block(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    The first stage maps ``ch_in`` to ``ch_out`` channels and the second keeps
    ``ch_out``. Stride 1 with padding 1 leaves the spatial size unchanged.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        # First pass consumes ch_in channels, second pass consumes ch_out.
        for source_channels in (ch_in, ch_out):
            stages.append(nn.Conv2d(source_channels, ch_out, kernel_size=3, stride=1, padding=1))
            stages.append(nn.BatchNorm2d(ch_out))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages."""
        features = self.conv(x)
        return features
class conv_up(nn.Module):
    """Upsample by a factor of two, then apply a 3x3 conv, BatchNorm and ReLU.

    The convolution maps ``ch_in`` to ``ch_out`` channels and keeps the
    (already doubled) spatial size.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        project = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        normalise = nn.BatchNorm2d(ch_out)
        activate = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(enlarge, project, normalise, activate)

    def forward(self, x):
        """Return the upsampled and re-projected feature map."""
        upsampled = self.conv(x)
        return upsampled
class U_net(nn.Module):
    """U-Net style encoder/decoder with four down- and four up-sampling steps.

    The encoder widens the channels 32 -> 64 -> 128 -> 256 -> 512 and halves
    the resolution with 2x2 max-pooling between stages. The decoder mirrors
    it, concatenating each upsampled map with the encoder map of the same
    level. A 1x1 convolution followed by a sigmoid produces ``ch_out``
    channels with values in (0, 1).

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels of the output tensor.
    """

    def __init__(self, ch_in, ch_out):
        super(U_net, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder
        self.conv1 = conv_block(ch_in=ch_in, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)
        # Decoder: each conv_up halves the channels, the concatenation with
        # the skip map doubles them again before the following conv_block.
        self.up5 = conv_up(ch_in=512, ch_out=256)
        self.conv_up5 = conv_block(ch_in=512, ch_out=256)
        self.up4 = conv_up(ch_in=256, ch_out=128)
        self.conv_up4 = conv_block(ch_in=256, ch_out=128)
        self.up3 = conv_up(ch_in=128, ch_out=64)
        self.conv_up3 = conv_block(ch_in=128, ch_out=64)
        self.up2 = conv_up(ch_in=64, ch_out=32)
        self.conv_up2 = conv_block(ch_in=64, ch_out=32)
        self.conv1_1 = nn.Conv2d(in_channels=32, out_channels=ch_out, kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _concat_skip(skip, upsampled):
        """Concatenate ``skip`` and ``upsampled`` along the channel axis.

        Max-pooling floors odd sizes, so after upsampling the decoder map can
        be one pixel smaller than the encoder map. In that case it is resized
        to the encoder map's spatial size; matching sizes are left untouched.
        """
        if upsampled.shape[2:] != skip.shape[2:]:
            upsampled = nn.functional.interpolate(
                upsampled, size=skip.shape[2:], mode='nearest')
        return torch.cat((skip, upsampled), dim=1)

    def forward(self, x):
        """Map ``x`` to a tensor with ``ch_out`` channels and the same height/width."""
        # Encoder path; x1..x4 are kept for the skip connections.
        x1 = self.conv1(x)
        x2 = self.conv2(self.maxpool(x1))
        x3 = self.conv3(self.maxpool(x2))
        x4 = self.conv4(self.maxpool(x3))
        x5 = self.conv5(self.maxpool(x4))

        # Decoder path.
        d5 = self.conv_up5(self._concat_skip(x4, self.up5(x5)))
        d4 = self.conv_up4(self._concat_skip(x3, self.up4(d5)))
        d3 = self.conv_up3(self._concat_skip(x2, self.up3(d4)))
        d2 = self.conv_up2(self._concat_skip(x1, self.up2(d3)))

        d1 = self.conv1_1(d2)
        return torch.sigmoid(d1)
class CNN(nn.Module):
    """Convolutional classifier: five conv stages, global pooling, linear head.

    The stages widen the channels 32 -> 64 -> 128 -> 256 -> 512 with 2x2
    max-pooling between consecutive stages. The resulting map is averaged to
    1x1, flattened and passed through a linear layer with a sigmoid, giving
    ``num_class`` values in (0, 1) per sample.
    """

    def __init__(self, ch_in, num_class=1):
        super().__init__()
        base_width = 32
        layers = []
        incoming = ch_in
        for level in range(5):
            # A pooling layer sits between stages, never before the first.
            if level > 0:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            outgoing = base_width * (2 ** level)
            layers.append(conv_block(ch_in=incoming, ch_out=outgoing))
            incoming = outgoing
        self.dis = nn.Sequential(*layers)
        self.fc = nn.Sequential(
            nn.Linear(incoming, num_class),
            nn.Sigmoid(),
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the per-sample sigmoid scores for ``x``."""
        pooled = self.avg_pool(self.dis(x))
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)
class Trainer(object):
    """Alternating adversarial training of a U_net generator and a CNN discriminator.

    The generator maps an input batch to an output batch. The discriminator
    scores the channel-wise concatenation of (output, input) pairs: real
    pairs should score 1, generated pairs 0. Even-numbered batches update the
    generator, odd-numbered batches update the discriminator.

    Args:
        ch_in: Channels of the generator input.
        ch_out: Channels of the generator output (and of the labels).
        epoch: Number of passes over the dataset.
        batchsize: Batch size used by the data loader.
        lr: Learning rate of both Adam optimizers.
        dataset: Dataset yielding ``(input, target)`` pairs.
        rec_weight: Weight of the L1 reconstruction term in the generator loss.

    Raises:
        ValueError: If ``dataset`` is not provided.
    """

    def __init__(self, ch_in=3, ch_out=3, epoch=50, batchsize=16, lr=0.005,
                 dataset=None, rec_weight=100):
        # Imported locally so the class does not depend on how the rest of
        # the file spells this import.
        from torch.utils.data import DataLoader

        if dataset is None:
            raise ValueError('dataset must be provided')
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.epoch = epoch
        self.batchsize = batchsize
        self.lr = lr
        self.rec_weight = rec_weight
        self.dataset_loader = DataLoader(dataset=dataset, batch_size=self.batchsize, shuffle=True)
        # Generator
        self.gen = U_net(ch_in=self.ch_in, ch_out=self.ch_out)
        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(), lr=self.lr)
        self.gen_loss = nn.L1Loss()
        # Discriminator: sees the generator output (or the target) stacked on
        # the input, hence ch_out + ch_in channels.
        self.dis = CNN(ch_in=self.ch_in + self.ch_out, num_class=1)
        self.dis_optimizer = torch.optim.Adam(self.dis.parameters(), lr=self.lr)
        self.dis_loss = nn.BCELoss()

    def set_requires_grad(self, nets, requires_grad):
        """Set ``requires_grad`` on every parameter of one network or a list of networks.

        ``None`` entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for parm in net.parameters():
                    parm.requires_grad = requires_grad

    def _generator_step(self, bx, by):
        """Update the generator on one batch and return its loss as a float."""
        # The discriminator is frozen: it only passes gradients back.
        self.set_requires_grad(self.dis, False)
        self.set_requires_grad(self.gen, True)
        bx_gen = self.gen(bx)
        loss_rec = self.gen_loss(bx_gen, by)
        dis_fake = self.dis(torch.cat([bx_gen, bx], dim=1))
        # The generator wants its output to be classified as real (1).
        loss_adv = self.dis_loss(dis_fake, torch.ones_like(dis_fake))
        loss_gen = loss_adv + self.rec_weight * loss_rec
        self.gen_optimizer.zero_grad()
        loss_gen.backward()
        self.gen_optimizer.step()
        return loss_gen.item()

    def _discriminator_step(self, bx, by):
        """Update the discriminator on one batch and return its loss as a float."""
        self.set_requires_grad(self.dis, True)
        self.set_requires_grad(self.gen, False)
        # No gradient is needed through the generator in this step.
        with torch.no_grad():
            bx_gen = self.gen(bx)
        dis_fake = self.dis(torch.cat([bx_gen, bx], dim=1))
        dis_real = self.dis(torch.cat([by, bx], dim=1))
        loss_fake = self.dis_loss(dis_fake, torch.zeros_like(dis_fake))
        loss_real = self.dis_loss(dis_real, torch.ones_like(dis_real))
        loss_dis = (loss_fake + loss_real) / 2
        self.dis_optimizer.zero_grad()
        loss_dis.backward()
        self.dis_optimizer.step()
        return loss_dis.item()

    def train(self):
        """Run the alternating training loop.

        Returns:
            A list with one ``(generator_loss_sum, discriminator_loss_sum)``
            tuple per epoch.
        """
        history = []
        for _ in range(self.epoch):
            epoch_gen_loss = 0.0
            epoch_dis_loss = 0.0
            for i, (bx, by) in enumerate(self.dataset_loader):
                if i % 2 == 0:
                    # Train the generator
                    loss_value = self._generator_step(bx, by)
                    epoch_gen_loss += loss_value
                    print('生成器损失', loss_value)
                else:
                    # Train the discriminator
                    loss_value = self._discriminator_step(bx, by)
                    epoch_dis_loss += loss_value
                    print('判别器损失', loss_value)
            history.append((epoch_gen_loss, epoch_dis_loss))
        # Leave both networks trainable regardless of which step ran last.
        self.set_requires_grad([self.gen, self.dis], True)
        return history