使用Encoder-Decoder+Gan网络结构修复图像
E-D阶段用于学习图像特征生成待修补区域对应的预测图,使用GAN对抗学习来优化模型
针对联合损失和规则遮挡的encoder-decoder+GAN
自定义一个生成器网络G:Encoder-Decoder
过程:
(1)自定义一个类,继承自Module类,实现两个基本的函数,第一是构造函数__init__,第二个是层的逻辑运算函数,即前向计算函数forward函数
(2)在构造函数_init__中实现层的参数定义,比如Linear层的权重和偏置,Conv2d层的channels, kernel_size, stride=1,padding=1,bias=False
(3)在前向传播forward函数里面实现前向运算。
#定义生成器网络G 输入128*128大小被遮挡的图片,输出64*64大小的只有遮挡部位的图片
class _netG(nn.Module):
def __init__(self, opt): #一般在__init__中定义网络需要的操作算子,比如卷积、全连接算子等等
super(_netG, self).__init__()#初始化参数
self.ngpu = opt.ngpu
self.main = nn.Sequential(
#编码器 输入128*128的遮挡图,经过5次上采样卷积操作,
#input is (nc) x 128 x 128 输入=通道数*128*128
nn.Conv2d(opt.nc,opt.nef,4,2,1, bias=False), # kernel_size=4, stride=2, padding=1
nn.LeakyReLU(0.2, inplace=True),
#layer2输入的是nef*64*64—>64 x 32 x 32
nn.Conv2d(opt.nef,opt.nef,4,2,1, bias=False),
nn.BatchNorm2d(opt.nef),
nn.LeakyReLU(0.2, inplace=True),
#layer3 64 x 32 x 32—>128 x 16 x 16
nn.Conv2d(opt.nef,opt.nef*2,4,2,1, bias=False),
nn.BatchNorm2d(opt.nef*2),
nn.LeakyReLU(0.2, inplace=True), #relu中f=maxy(0,x),而leakyrelu中f=x>0?x:ax(a=栏目大)
#layer4:128 x 16 x 16—>256x 8 x 8
nn.Conv2d(opt.nef*2,opt.nef*4,4,2,1, bias=False),
nn.BatchNorm2d(opt.nef*4),
nn.LeakyReLU(0.2, inplace=True),
# layer5:256 x 8 x 8—>512x 4 x 4
nn.Conv2d(opt.nef*4,opt.nef*8,4,2,1, bias=False),
nn.BatchNorm2d(opt.nef*8),
nn.LeakyReLU(0.2, inplace=True),
# state size:(nef*8) x 4 x 4
nn.Conv2d(opt.nef*8,opt.nBottleneck,4, bias=False),
# tate size: (nBottleneck) x 1 x 1
nn.BatchNorm2d(opt.nBottleneck),
nn.LeakyReLU(0.2, inplace=True),
#解码器 上采样过程是5次逆卷积操作input 512*4*4->output 3*64*64
# input is Bottleneck, going into a convolution
nn.ConvTranspose2d(opt.nBottleneck, opt.ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(opt.ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False),#变成3通道,输出3*64*64
nn.Tanh()#激活函数,可以达到优化模型的效果
# state size. (nc) x 64 x 64
)
#上面的是将所有的层都放在了构造函数__init__里面,但是只是定义了一系列的层,各个层之间什么连接关系并没有,而是在forward里面实现所有层的连接关系,当然这里依然是顺序连接的。
#定义forward()前向传输,
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:#ngpu表示gpu的个数,当n>1使用并发处理
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output
class _netlocalD(nn.Module):
def __init__(self, opt):
super(_netlocalD, self).__init__()
self.ngpu = opt.ngpu
self.main = nn.Sequential(
#输入遮挡部分的真实图像64*64
# input is (nc) x 64 x 64=3*64*64
#layer1 3*64*64->64*32*32
nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32=64*32*32 ndf卷积核个数,也就是滤波器的个数
#layer2 64*32*32->128*16*16
nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
#layer3 128*16*16->256*8*8
nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
#layer4 state size. (ndf*4) x 8 x 8
nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(opt.ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
#layer5 state size. (ndf*8) x 4 x 4
nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid() #sigmoid是激活函数的一种,它会将样本值映射到0到1之间。
)
def forward(self, input):
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
else:
output = self.main(input)
return output.view(-1, 1)
import 所需要的模块
创建 ArgumentParser()对象
调用 add_argument()方法添加参数
—dataset 指定训练数据集
—dataroot 指定数据集下载路径或者已经存在的数据集路径
—workers 进行数据预处理及数据加载使用进程数
—batchSize 一次batch进入模型的图片数目
—imageSize 原始图片重采样进入模型前的大小
—nz 初始噪音向量的大小(Size of latent zz vector)
—ngf 生成网络中基础feature数目
—ndf 判别网络中基础feature数目
—netG 指定生成网络路径
—netD 指定判别网路径
—niter网络训练过程中epoch数目
—lr 初始学习率
—beta1 使用Adam优化算法中的β1β
-nef 第一个卷积层的滤波器数量
-overlapPred 步长(stride)小于卷积核的边长,出现卷积核与原始输入矩阵作用范围在区域上的重叠(overlap),一致时,不会出现重叠现象。
-nBottleneck编码器nBottleneck的数量
—cuda 指定使用GPU进行训练
—outf 模型输出图片的保存路径
—manualSeed 指定生成随机数的seed
-wtl2 L2损失函数的权重0.998
-wtlD 对抗损失的函数0.001
训练次数nither=25,学习速率lr=0.0002
目的:将数据集变成自己想要的格式和大小
if opt.dataset ='streetview':
dataset = dset.ImageFolder(root=opt.dataroot,
transform=transforms.Compose([ #组合多个transforms的操作
transforms.Scale(opt.imageSize), #调整到需要的大小
transforms.CenterCrop(opt.imageSize), #在图像中心区域进行裁剪
transforms.ToTensor(),#将对象转换为tensor,把灰度范围从0-255变换到0-1之间
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
#(tensor,平均值、标准差)把0-1变换到(-1,1)计算方式image=(image-mean)/std image=(image-0.5)/0.5
dataset = dset.ImageFolder(root=opt.dataroot, transform=transform)
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))
#shuffle=True用于打乱数据集,每次都会以不同的顺序返回
##在netG和netD上调用自定义权重初始化,这里是对整个网络进行初始化定义
def weights_init(m):
classname = m.__class__.__name__#得到了网络层的名字
if classname.find('Conv') != -1:#使用了find函数,如果不存在返回值为-1,所以让其不等于-1
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
resume_epoch=0 #更新一次训练,当一个完整的数据集通过了神经网络一次并且返回了一次
netG = _netG(opt)
netG.apply(weights_init)#apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上
if opt.netG != '':
netG.load_state_dict(torch.load(opt.netG,map_location=lambda storage, location: storage)['state_dict'])#torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中。
resume_epoch = torch.load(opt.netG)['epoch']
print(netG)
netD = _netlocalD(opt)
netD.apply(weights_init)
if opt.netD != '':
netD.load_state_dict(torch.load(opt.netD,map_location=lambda storage, location: storage)['state_dict'])
resume_epoch = torch.load(opt.netD)['epoch']
print(netD)
criterion = nn.BCELoss() ##二元交叉熵损失函数BCELoss
criterionMSE = nn.MSELoss()#均方误差
input_real = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
input_cropped = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
label = torch.FloatTensor(opt.batchSize)#????
real_label = 1 #真标签为1
fake_label = 0 #加标签为0
real_center = torch.FloatTensor(opt.batchSize, 3,int(opt.imageSize/2), int(opt.imageSize/2))#真实的中间图片
设置优化器:神经网络训练时,采用梯度下降,更新权重参数,逐渐逼近最小的loss
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
#指定优化的参数(优化模型的参数,学习速率)
将真实图片和生成器生成的虚假图片也送入判别器进行判别,然后对抗训练判别器网络,使用对抗损失也就是我们的联合损失不断更新判别器
for epoch in range(resume_epoch,opt.niter):
for i, data in enumerate(dataloader, 0):
real_cpu, _ = data
real_center_cpu = real_cpu[:,:,int(opt.imageSize/4):int(opt.imageSize/4)+int(opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4)+int(opt.imageSize/2)]
batch_size = real_cpu.size(0)
input_real.resize_(real_cpu.size()).copy_(real_cpu)
input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
input_cropped.data
[:,0,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*117.0/255.0 - 1.0
input_cropped.data
[:,1,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*104.0/255.0 - 1.0
input_cropped.data
[:,2,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*123.0/255.0 - 1.0
#训练真实数据
netD.zero_grad()#判别器优化器梯度全部降为0
#让D尽可能的把真图片判别为1
label.resize_(batch_size).fill_(real_label)#标签全部改为1,一开始判断真实图片
output = netD(real_center) #判别器输出
output=output.squeeze(dim=-1)
errD_real = criterion(output, label) #计算判断真实图片的损失值
errD_real.backward() #反向传播
D_x = output.data.mean()
#train with fake训练虚假数据
#让D尽可能把假图片判别为0
fake = netG(input_cropped) #生成假图
label.data.fill_(fake_label) #标签全部改为0,一开始假图片
output = netD(fake.detach()) #对一个批次假图片进行分类 ,detach()里面的才会计算到
output=output.squeeze(-1)
errD_fake = criterion(output, label) #计算判断假图片为假的损失值
errD_fake.backward() #反向传播
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake #判断真图片和判断假图片的损失值加和作为总损失
optimizerD.step() #优化判别器
固定判别器,训练生成器
netG.zero_grad()#生成器梯度全部降为0
#让D尽可能把G生成的假图判别为1
label.data.fill_(real_label) # fake labels are real for generator cost #标签全部改为1,一开始判断真实图片
output = netD(fake) #判别器输出,判别刚才生成的假图片
output=output.squeeze(-1) ##修改 修改 output多了一个维度,需要把最后一个维度squeeze掉
errG_D = criterion(output, label) #计算判断假图片为真的损失值
wtl2Matrix = real_center.clone()
wtl2Matrix.data.fill_(wtl2*overlapL2Weight)
wtl2Matrix.data[:,:,int(opt.overlapPred):int(opt.imageSize/2 - opt.overlapPred),int(opt.overlapPred):int(opt.imageSize/2 - opt.overlapPred)] = wtl2
#计算L2的误差值
errG_l2 = (fake-real_center).pow(2)
errG_l2 = errG_l2 * wtl2Matrix
errG_l2 = errG_l2.mean()
errG = (1-wtl2) * errG_D + wtl2 * errG_l2#判别器和生成器的损失之和作为总损失
errG.backward() #反向传播
D_G_z2 = output.data.mean()
optimizerG.step() #优化生成器
![第25次修复的图像](C:\Users\15600\Desktop\学习\context encoder\第25次修复的图像.PNG) print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
% (epoch, opt.niter, i, len(dataloader),
errD.item(), errG_D.item(),errG_l2.item(), D_x,D_G_z1, ))
#保存图像
if i % 100 == 0: #每100幅图像放在一张照片中
vutils.save_image(real_cpu,
'result/train/real/real_samples_epoch_%03d.png' % (epoch))
vutils.save_image(input_cropped.data,
'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
recon_image = input_cropped.clone()
recon_image.data
[:,:,int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2)] = fake.data
vutils.save_image(recon_image.data,
'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
# do checkpointing检查点
torch.save({'epoch':epoch+1,
'state_dict':netG.state_dict()},
'model/netG_streetview.pth' )
torch.save({'epoch':epoch+1,
'state_dict':netD.state_dict()},
'model/netlocalD.pth' )
训练效果
nither=0
nither=24
nither在(0,25)之间的时候,很明显随着训练次数的增多,修复效果明显变好(当然原作者训练了250次,我电脑配置不行,跑不动啊,25次已经到尽头了)
测试集用来测试图片的修复效果,因此前面的模型定义以及图片处理方式和train.py里面的步骤一样,这里不再赘述,我们这里只展示输出的结果。
t = real_center - fake
l2 = np.mean(np.square(t))
l1 = np.mean(np.abs(t))
real_center = (real_center+1)*127.5
fake = (fake+1)*127.5
for i in range(opt.batchSize):
p = p + psnr(real_center[i].transpose(1,2,0) , fake[i].transpose(1,2,0))
print(l2)
print(l1)
print(p/opt.batchSize)
输出L2:均方损失
输出L1:对抗损失
输出P:峰值信噪比
下面是它的测试结果
与原论文相比,由于我们的训练次数太少,因此修复效果不是太好,但本文重点旨在复现整个论文代码实现的流程,重在学习!!!