生成对抗学习

生成对抗学习

  • 自动编码器复习
  • 生成式对抗网络介绍
  • 训练判别器
  • 训练生成器
  • 数据集
  • 模型构建
    • 生成器
    • 判别器
  • 模型训练

自动编码器复习

核心目标:训练网络使输出尽可能还原输入(输出≈输入),中间层的编码即网络学到的特征表示。
用途:降维、特征提取、(通过逐层预训练)初始化深度网络
训练方式:梯度下降+反向传播

生成式对抗网络介绍

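最小最大游戏(零和博弈):
游戏双方分别是生成器 $G$ 和判别器 $D$。生成器学习伪造数据,判别器学习判断数据的真伪。双方优化同一个目标函数:
$$\min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{\text{data}}}\left[\log D(x)\right]+\mathbb{E}_{z\sim p_z}\left[\log\left(1-D(G(z))\right)\right]$$
为了赢得博弈,双方不断自我优化,分别提高生成能力和判别能力,最终生成器产生的数据足以以假乱真。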

训练判别器

从真实数据集中采样数据,并标记为1。
从生成器采样生成数据,并标记为0。
固定生成器参数(不更新),只通过反向传播训练判别器。

训练生成器

固定判别器参数(不更新判别器),但让梯度经由判别器反向传播到生成器。
从生成器采样数据送入判别器,并把这些生成数据标记为1(即希望判别器将其判为真)。
通过反向传播只更新生成器参数;判别器在这一步仅负责提供梯度,本身不被训练。

数据集

class Image_data(Dataset):
    """Paired image-to-image dataset: every input image has a label image with the same stem.

    Expected layout (joined under ``path``)::

        path/data_path/<stem>.<ext>     input image
        path/label_path/<stem>.jpg      label (target) image

    Args:
        img_h, img_w: size both images are resized to when ``process`` is true.
        path: dataset root directory.
        data_path: sub-directory of ``path`` holding the input images.
        label_path: sub-directory of ``path`` holding the label images.
        process: when true, ``__getitem__`` returns tensors (input normalized
            with mean/std 0.5 per channel, label as produced by ``ToTensor``);
            when false it returns the PIL images unchanged.

    Raises:
        ValueError: if ``path``, ``data_path`` or ``label_path`` is not given.
    """

    def __init__(self, img_h=256, img_w=256, path=None, data_path=None,
                 label_path=None, process=True):
        # The original signature listed required parameters after defaulted
        # ones (a SyntaxError). None defaults keep the same positional order
        # while still rejecting missing directories explicitly.
        if path is None or data_path is None or label_path is None:
            raise ValueError("path, data_path and label_path are required")
        self.img_h = img_h
        self.img_w = img_w
        self.path = path
        self.data_path = data_path
        self.label_path = label_path
        self.process = process
        # os.listdir order is arbitrary; sorting makes index -> file reproducible.
        self.img_data = sorted(os.listdir(os.path.join(self.path, self.data_path)))

        self.transforms_image = None
        self.transforms_label = None
        if self.process:
            # Built once here rather than on every __getitem__ call.
            self.transforms_image = transforms.Compose([
                transforms.Resize([self.img_h, self.img_w]),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
            self.transforms_label = transforms.Compose([
                transforms.Resize([self.img_h, self.img_w]),
                transforms.ToTensor(),
            ])

    def __len__(self):
        return len(self.img_data)

    def _paths(self, item):
        """Return ``(image_path, label_path)`` for sample index ``item``."""
        image_name = self.img_data[item]
        # splitext keeps stems that contain dots intact ("a.v2.png" -> "a.v2").
        stem = os.path.splitext(image_name)[0]
        image_path = os.path.join(self.path, self.data_path, image_name)
        label_path = os.path.join(self.path, self.label_path, stem + '.jpg')
        return image_path, label_path

    def __getitem__(self, item):
        image_path, label_path = self._paths(item)
        # The context managers close the file handles. convert('RGB') forces a
        # full load and guarantees the 3 channels the Normalize above expects.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # copy() loads the pixels so the file can be closed; the label's
        # original mode is preserved.
        with Image.open(label_path) as lbl:
            label = lbl.copy()
        if self.process:
            image = self.transforms_image(image)
            label = self.transforms_label(label)
        return image, label

模型构建

生成器

class conv_block(nn.Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Stride 1 with padding 1 preserves height and width; only the channel
    count changes, from ``ch_in`` to ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = []
        # First stage reads ch_in channels, second stage reads ch_out.
        for stage_in in (ch_in, ch_out):
            layers += [
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(ch_out),
                nn.ReLU(inplace=True),
            ]
        # A single Sequential named ``conv`` keeps the state_dict keys
        # (conv.0.*, conv.1.*, conv.3.*, conv.4.*) stable.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return out

class conv_up(nn.Module):
    """Decoder step: 2x upsample, then 3x3 Conv2d -> BatchNorm2d -> ReLU.

    Height and width double (nn.Upsample with its default ``nearest`` mode);
    channels go from ``ch_in`` to ``ch_out``.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        norm = nn.BatchNorm2d(ch_out)
        act = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        out = self.conv(x)
        return out

class U_net(nn.Module):
    """U-Net generator with four down/up levels and skip connections.

    Encoder widths are 32-64-128-256-512. Each encoder level halves H and W
    with a 2x2 max-pool. Each decoder level doubles them and concatenates the
    matching encoder output in front of it along dim 1, so H and W must be
    divisible by 16 for the concatenations to line up. A 1x1 convolution maps
    to ``ch_out`` channels and a sigmoid squashes the result into (0, 1).
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder. Registration order matches the original so parameter
        # initialisation and state_dict key order are unchanged.
        self.conv1 = conv_block(ch_in=ch_in, ch_out=32)
        self.conv2 = conv_block(ch_in=32, ch_out=64)
        self.conv3 = conv_block(ch_in=64, ch_out=128)
        self.conv4 = conv_block(ch_in=128, ch_out=256)
        self.conv5 = conv_block(ch_in=256, ch_out=512)

        # Decoder: upN halves the channels while doubling resolution; conv_upN
        # fuses it with the same-width skip, hence its doubled input width.
        self.up5 = conv_up(ch_in=512, ch_out=256)
        self.conv_up5 = conv_block(ch_in=512, ch_out=256)
        self.up4 = conv_up(ch_in=256, ch_out=128)
        self.conv_up4 = conv_block(ch_in=256, ch_out=128)
        self.up3 = conv_up(ch_in=128, ch_out=64)
        self.conv_up3 = conv_block(ch_in=128, ch_out=64)
        self.up2 = conv_up(ch_in=64, ch_out=32)
        self.conv_up2 = conv_block(ch_in=64, ch_out=32)

        self.conv1_1 = nn.Conv2d(32, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Encoder: remember every level's output for the skip connections.
        features = [self.conv1(x)]
        for encode in (self.conv2, self.conv3, self.conv4, self.conv5):
            features.append(encode(self.maxpool(features[-1])))

        # The deepest (conv5) output seeds the decoder and is not a skip.
        h = features.pop()
        decoder = (
            (self.up5, self.conv_up5),
            (self.up4, self.conv_up4),
            (self.up3, self.conv_up3),
            (self.up2, self.conv_up2),
        )
        for upsample, fuse in decoder:
            h = fuse(torch.cat((features.pop(), upsample(h)), dim=1))

        return torch.sigmoid(self.conv1_1(h))

判别器

class CNN(nn.Module):
    """Discriminator: conv encoder -> global average pool -> Linear -> Sigmoid.

    Five ``conv_block`` stages of width 32, 64, 128, 256, 512 are separated by
    four 2x2 max-pools. The adaptive average pool reduces any remaining
    spatial extent to 1x1, so the classifier head has a fixed 512-wide input.
    Output shape is (batch, num_class) with values in (0, 1).
    """

    def __init__(self, ch_in, num_class=1):
        super().__init__()
        ndf = 32
        stages = [conv_block(ch_in=ch_in, ch_out=ndf)]
        width = ndf
        # Indices match the original Sequential: odd = pool, even = conv_block.
        for _ in range(4):
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
            stages.append(conv_block(ch_in=width, ch_out=2 * width))
            width *= 2
        self.dis = nn.Sequential(*stages)

        # width == 16 * ndf here.
        self.fc = nn.Sequential(nn.Linear(width, num_class), nn.Sigmoid())

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        pooled = self.avg_pool(self.dis(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

模型训练

class Trainer(object):
    """Alternating adversarial trainer for a conditional (pix2pix-style) GAN.

    The generator (``U_net``) maps an input batch ``bx`` to a prediction of the
    paired target ``by``. The discriminator (``CNN``) scores an (image,
    condition) pair concatenated along dim 1. Even-numbered batches update the
    generator; odd-numbered batches update the discriminator.

    Args:
        ch_in: channels of the input images.
        ch_out: channels produced by the generator (and of the targets).
        epoch: number of passes over the dataset.
        batchsize: mini-batch size.
        lr: Adam learning rate for both networks.
        dataset: a Dataset yielding ``(input, target)`` tensor pairs.
        rec_weight: weight of the L1 reconstruction term in the generator loss.

    Raises:
        ValueError: if ``dataset`` is not given.
    """

    def __init__(self, ch_in=3, ch_out=3, epoch=50, batchsize=16, lr=0.005,
                 dataset=None, rec_weight=100):
        # Local import: the original referenced the non-existent ``Dataloader``.
        from torch.utils.data import DataLoader

        if dataset is None:
            raise ValueError("dataset is required")
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.epoch = epoch
        self.batchsize = batchsize
        # NOTE(review): 0.005 is high for Adam in GAN training (the pix2pix
        # paper uses 2e-4); kept as the default for compatibility.
        self.lr = lr
        self.rec_weight = rec_weight
        self.dataset_loader = DataLoader(dataset=dataset, batch_size=self.batchsize, shuffle=True)

        # Generator: trained with adversarial BCE + rec_weight * L1 reconstruction.
        self.gen = U_net(ch_in=self.ch_in, ch_out=self.ch_out)
        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(), lr=self.lr)
        self.gen_loss = nn.L1Loss()

        # Discriminator: its input is cat([candidate (ch_out), condition (ch_in)]),
        # so it needs ch_in + ch_out channels (== 2 * ch_in in the default setup).
        self.dis = CNN(ch_in=self.ch_in + self.ch_out, num_class=1)
        self.dis_optimizer = torch.optim.Adam(self.dis.parameters(), lr=self.lr)
        self.dis_loss = nn.BCELoss()

    def set_requires_grad(self, nets, requires_grad):
        """Set ``requires_grad`` on every parameter of ``nets``.

        ``nets`` is a module or a list of modules; ``None`` entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for parm in net.parameters():
                    parm.requires_grad = requires_grad

    def _train_generator_step(self, bx, by):
        """Run one generator update on a batch and return the loss value."""
        # D is frozen so its weights get no gradients, but gradients still flow
        # through it back to G.
        self.set_requires_grad(self.dis, False)
        self.set_requires_grad(self.gen, True)
        bx_gen = self.gen(bx)
        loss_rec = self.gen_loss(bx_gen, by)

        dis_fake = self.dis(torch.cat([bx_gen, bx], dim=1))
        # G wants D to score its output as real (label 1). BCE matches D's
        # Sigmoid output; the original used L1 here and never used dis_loss.
        loss_adv = self.dis_loss(dis_fake, torch.ones_like(dis_fake))
        loss_gen = loss_adv + self.rec_weight * loss_rec

        self.gen_optimizer.zero_grad()
        loss_gen.backward()
        self.gen_optimizer.step()
        return loss_gen.item()

    def _train_discriminator_step(self, bx, by):
        """Run one discriminator update on a batch and return the loss value."""
        # The original froze D and unfroze G here, so D could never learn.
        self.set_requires_grad(self.dis, True)
        self.set_requires_grad(self.gen, False)
        # Only D is updated, so no autograd graph through G is needed.
        with torch.no_grad():
            bx_gen = self.gen(bx)

        dis_fake = self.dis(torch.cat([bx_gen, bx], dim=1))
        dis_real = self.dis(torch.cat([by, bx], dim=1))
        loss_fake = self.dis_loss(dis_fake, torch.zeros_like(dis_fake))
        loss_real = self.dis_loss(dis_real, torch.ones_like(dis_real))
        loss_dis = (loss_fake + loss_real) / 2

        self.dis_optimizer.zero_grad()
        loss_dis.backward()
        self.dis_optimizer.step()
        return loss_dis.item()

    def train(self):
        """Train for ``self.epoch`` epochs, alternating G (even batches) and D (odd batches).

        Prints every step's loss and a per-epoch mean.

        Returns:
            A list with one ``(mean_gen_loss, mean_dis_loss)`` tuple per epoch.
            A mean is 0.0 when that network took no step in the epoch.
        """
        self.gen.train()
        self.dis.train()
        history = []
        for epoch in range(self.epoch):
            epoch_gen_loss = 0.0
            epoch_dis_loss = 0.0
            gen_steps = 0
            dis_steps = 0
            for i, (bx, by) in enumerate(self.dataset_loader):
                if i % 2 == 0:
                    # Train the generator.
                    loss = self._train_generator_step(bx, by)
                    epoch_gen_loss += loss
                    gen_steps += 1
                    print('生成器损失', loss)
                else:
                    # Train the discriminator.
                    loss = self._train_discriminator_step(bx, by)
                    epoch_dis_loss += loss
                    dis_steps += 1
                    print('判别器损失', loss)
            mean_gen = epoch_gen_loss / gen_steps if gen_steps else 0.0
            mean_dis = epoch_dis_loss / dis_steps if dis_steps else 0.0
            print(f'epoch {epoch + 1}/{self.epoch}: gen {mean_gen:.4f} dis {mean_dis:.4f}')
            history.append((mean_gen, mean_dis))
        return history
					
		

你可能感兴趣的:(强化与提高,深度学习)