context encoder代码解读

context encoder代码解读

网络框架

使用Encoder-Decoder+Gan网络结构修复图像

E-D阶段用于学习图像特征生成待修补区域对应的预测图,使用GAN对抗学习来优化模型

针对联合损失和规则遮挡的encoder-decoder+GAN
context encoder代码解读_第1张图片

model.py

生成器

自定义一个生成器网络G:Encoder-Decoder

过程:

(1)自定义一个类,继承自Module类,实现两个基本的函数,第一是构造函数__init__,第二个是层的逻辑运算函数,即前向计算函数forward函数

(2)在构造函数_init__中实现层的参数定义,比如Linear层的权重和偏置,Conv2d层的channels, kernel_size, stride=1,padding=1,bias=False

(3)在前向传播forward函数里面实现前向运算。

#定义生成器网络G 输入128*128大小被遮挡的图片,输出64*64大小的只有遮挡部位的图片
class _netG(nn.Module):
	def __init__(self, opt): #一般在__init__中定义网络需要的操作算子,比如卷积、全连接算子等等
	super(_netG, self).__init__()#初始化参数
	self.ngpu = opt.ngpu
    self.main = nn.Sequential(
        #编码器  输入128*128的遮挡图,经过5次上采样卷积操作,
        #input is (nc) x 128 x 128 输入=通道数*128*128
    	nn.Conv2d(opt.nc,opt.nef,4,2,1, bias=False), # kernel_size=4, stride=2, padding=1
        nn.LeakyReLU(0.2, inplace=True),
        
        #layer2输入的是nef*64*64—>64 x 32 x 32
        nn.Conv2d(opt.nef,opt.nef,4,2,1, bias=False),
        nn.BatchNorm2d(opt.nef),
        nn.LeakyReLU(0.2, inplace=True),
        #layer3 64 x 32 x 32—>128 x 16 x 16
        nn.Conv2d(opt.nef,opt.nef*2,4,2,1, bias=False),
        nn.BatchNorm2d(opt.nef*2),
        nn.LeakyReLU(0.2, inplace=True), #relu中f=maxy(0,x),而leakyrelu中f=x>0?x:ax(a=栏目大)
        #layer4:128 x 16 x 16—>256x 8 x 8
        nn.Conv2d(opt.nef*2,opt.nef*4,4,2,1, bias=False),
        nn.BatchNorm2d(opt.nef*4),
        nn.LeakyReLU(0.2, inplace=True),
            
        # layer5:256 x 8 x 8—>512x 4 x 4
        nn.Conv2d(opt.nef*4,opt.nef*8,4,2,1, bias=False),
        nn.BatchNorm2d(opt.nef*8),
        nn.LeakyReLU(0.2, inplace=True),
            
        # state size:(nef*8) x 4 x 4
        nn.Conv2d(opt.nef*8,opt.nBottleneck,4, bias=False),
        # tate size: (nBottleneck) x 1 x 1
        nn.BatchNorm2d(opt.nBottleneck),
        nn.LeakyReLU(0.2, inplace=True),
        
        #解码器 上采样过程是5次逆卷积操作input 512*4*4->output 3*64*64
        # input is Bottleneck, going into a convolution
        nn.ConvTranspose2d(opt.nBottleneck, opt.ngf * 8, 4, 1, 0, bias=False),
        nn.BatchNorm2d(opt.ngf * 8),
        nn.ReLU(True),
            
        # state size. (ngf*8) x 4 x 4
        nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False),
        nn.BatchNorm2d(opt.ngf * 4),
        nn.ReLU(True),
            
        # state size. (ngf*4) x 8 x 8
        nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False),
        nn.BatchNorm2d(opt.ngf * 2),
        nn.ReLU(True),
            
        # state size. (ngf*2) x 16 x 16
        nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False),
        nn.BatchNorm2d(opt.ngf),
        nn.ReLU(True),
            
        # state size. (ngf) x 32 x 32
        nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False),#变成3通道,输出3*64*64
        nn.Tanh()#激活函数,可以达到优化模型的效果
        # state size. (nc) x 64 x 64
        )
        #上面的是将所有的层都放在了构造函数__init__里面,但是只是定义了一系列的层,各个层之间什么连接关系并没有,而是在forward里面实现所有层的连接关系,当然这里依然是顺序连接的。
         #定义forward()前向传输,
    def forward(self, input): 
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:#ngpu表示gpu的个数,当n>1使用并发处理
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

判别器

class _netlocalD(nn.Module):
    def __init__(self, opt):
        super(_netlocalD, self).__init__()
        self.ngpu = opt.ngpu
        self.main = nn.Sequential(
            
            #输入遮挡部分的真实图像64*64
            # input is (nc) x 64 x 64=3*64*64
            #layer1 3*64*64->64*32*32
            nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (ndf) x 32 x 32=64*32*32 ndf卷积核个数,也就是滤波器的个数
            #layer2 64*32*32->128*16*16
            nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(opt.ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            
            # state size. (ndf*2) x 16 x 16
            #layer3 128*16*16->256*8*8
            nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(opt.ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            #layer4 state size. (ndf*4) x 8 x 8
             nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(opt.ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            #layer5 state size. (ndf*8) x 4 x 4
           
            nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid() #sigmoid是激活函数的一种,它会将样本值映射到0到1之间。
        )

    def forward(self, input):
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1)

train.py

import 所需要的模块

创建 ArgumentParser()对象
调用 add_argument()方法添加参数

参数说明

—dataset 指定训练数据集
—dataroot 指定数据集下载路径或者已经存在的数据集路径
—workers 进行数据预处理及数据加载使用进程数
—batchSize 一次batch进入模型的图片数目
—imageSize 原始图片重采样进入模型前的大小
—nz  初始噪音向量的大小(Size of latent zz vector)
—ngf 生成网络中基础feature数目
—ndf 判别网络中基础feature数目 
—netG 指定生成网络路径
—netD 指定判别网路径
—niter网络训练过程中epoch数目
—lr  初始学习率
—beta1 使用Adam优化算法中的β1β
-nef  第一个卷积层的滤波器数量
-overlapPred 步长(stride)小于卷积核的边长,出现卷积核与原始输入矩阵作用范围在区域上的重叠(overlap),一致时,不会出现重叠现象。
-nBottleneck编码器nBottleneck的数量
—cuda 指定使用GPU进行训练
—outf 模型输出图片的保存路径
—manualSeed 指定生成随机数的seed
-wtl2 L2损失函数的权重0.998
-wtlD 对抗损失的函数0.001

训练次数nither=25,学习速率lr=0.0002

数据预处理

目的:将数据集变成自己想要的格式和大小

if opt.dataset ='streetview':
    dataset = dset.ImageFolder(root=opt.dataroot,
    transform=transforms.Compose([         #组合多个transforms的操作
    						transforms.Scale(opt.imageSize),       #调整到需要的大小
    						transforms.CenterCrop(opt.imageSize),  #在图像中心区域进行裁剪
    						transforms.ToTensor(),#将对象转换为tensor,把灰度范围从0-255变换到0-1之间
     						transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
#(tensor,平均值、标准差)把0-1变换到(-1,1)计算方式image=(image-mean)/std image=(image-0.5)/0.5
     dataset = dset.ImageFolder(root=opt.dataroot, transform=transform)
assert dataset   
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
                                         #shuffle=True用于打乱数据集,每次都会以不同的顺序返回

初始化网络

##在netG和netD上调用自定义权重初始化,这里是对整个网络进行初始化定义
def weights_init(m):
    classname = m.__class__.__name__#得到了网络层的名字
    if classname.find('Conv') != -1:#使用了find函数,如果不存在返回值为-1,所以让其不等于-1
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        
resume_epoch=0 #更新一次训练,当一个完整的数据集通过了神经网络一次并且返回了一次

netG = _netG(opt) 
netG.apply(weights_init)#apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG,map_location=lambda storage, location: storage)['state_dict'])#torch.load_state_dict()函数就是用于将预训练的参数权重加载到新的模型之中。
    resume_epoch = torch.load(opt.netG)['epoch']
print(netG)

netD = _netlocalD(opt)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD,map_location=lambda storage, location: storage)['state_dict'])
    resume_epoch = torch.load(opt.netD)['epoch']
print(netD)

criterion = nn.BCELoss()   ##二元交叉熵损失函数BCELoss
criterionMSE = nn.MSELoss()#均方误差

input_real = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
input_cropped = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
label = torch.FloatTensor(opt.batchSize)#????
real_label = 1  #真标签为1
fake_label = 0  #加标签为0

real_center = torch.FloatTensor(opt.batchSize, 3,int(opt.imageSize/2), int(opt.imageSize/2))#真实的中间图片

优化器设置

设置优化器:神经网络训练时,采用梯度下降,更新权重参数,逐渐逼近最小的loss

optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
#指定优化的参数(优化模型的参数,学习速率)

训练判别器

将真实图片和生成器生成的虚假图片也送入判别器进行判别,然后对抗训练判别器网络,使用对抗损失也就是我们的联合损失不断更新判别器

for epoch in range(resume_epoch,opt.niter):
    for i, data in enumerate(dataloader, 0):
         real_cpu, _ = data
         real_center_cpu = real_cpu[:,:,int(opt.imageSize/4):int(opt.imageSize/4)+int(opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4)+int(opt.imageSize/2)]
        batch_size = real_cpu.size(0)
        input_real.resize_(real_cpu.size()).copy_(real_cpu)
        input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
        real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
        input_cropped.data
           [:,0,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*117.0/255.0 - 1.0
        input_cropped.data
        [:,1,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*104.0/255.0 - 1.0
         input_cropped.data
        [:,2,int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred),int(opt.imageSize/4+opt.overlapPred):int(opt.imageSize/4+opt.imageSize/2-opt.overlapPred)] = 2*123.0/255.0 - 1.0
        
        #训练真实数据
        netD.zero_grad()#判别器优化器梯度全部降为0
        #让D尽可能的把真图片判别为1
        label.resize_(batch_size).fill_(real_label)#标签全部改为1,一开始判断真实图片
        output = netD(real_center)        #判别器输出
        output=output.squeeze(dim=-1)                       
        errD_real = criterion(output, label) #计算判断真实图片的损失值
        errD_real.backward() #反向传播
        D_x = output.data.mean()
        
        #train with fake训练虚假数据
        #让D尽可能把假图片判别为0
        fake = netG(input_cropped)  #生成假图
        label.data.fill_(fake_label) #标签全部改为0,一开始假图片
        output = netD(fake.detach()) #对一个批次假图片进行分类 ,detach()里面的才会计算到
        output=output.squeeze(-1)     

        errD_fake = criterion(output, label) #计算判断假图片为假的损失值
        errD_fake.backward() #反向传播
        D_G_z1 = output.data.mean() 
        errD = errD_real + errD_fake #判断真图片和判断假图片的损失值加和作为总损失
        optimizerD.step() #优化判别器

训练生成器

固定判别器,训练生成器

        netG.zero_grad()#生成器梯度全部降为0
        #让D尽可能把G生成的假图判别为1
        label.data.fill_(real_label)  # fake labels are real for generator cost #标签全部改为1,一开始判断真实图片
        output = netD(fake)           #判别器输出,判别刚才生成的假图片
        output=output.squeeze(-1)     ##修改 修改 output多了一个维度,需要把最后一个维度squeeze掉                      

        errG_D = criterion(output, label) #计算判断假图片为真的损失值

损失函数

        wtl2Matrix = real_center.clone()
        wtl2Matrix.data.fill_(wtl2*overlapL2Weight)
        wtl2Matrix.data[:,:,int(opt.overlapPred):int(opt.imageSize/2 - opt.overlapPred),int(opt.overlapPred):int(opt.imageSize/2 - opt.overlapPred)] = wtl2
        
        #计算L2的误差值
        errG_l2 = (fake-real_center).pow(2)
        errG_l2 = errG_l2 * wtl2Matrix 
        errG_l2 = errG_l2.mean()       

        errG = (1-wtl2) * errG_D + wtl2 * errG_l2#判别器和生成器的损失之和作为总损失

        errG.backward() #反向传播

        D_G_z2 = output.data.mean() 
        optimizerG.step() #优化生成器

输出结果

![25次修复的图像](C:\Users\15600\Desktop\学习\context encoder\第25次修复的图像.PNG)		print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG_D.item(),errG_l2.item(), D_x,D_G_z1, ))
         #保存图像
   		 if i % 100 == 0: #每100幅图像放在一张照片中
            vutils.save_image(real_cpu,
                    'result/train/real/real_samples_epoch_%03d.png' % (epoch))
            vutils.save_image(input_cropped.data,
                    'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
            recon_image = input_cropped.clone()
            recon_image.data
                [:,:,int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2),int(opt.imageSize/4):int(opt.imageSize/4+opt.imageSize/2)] = fake.data
            vutils.save_image(recon_image.data,
                    'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
    
    # do checkpointing检查点
    torch.save({'epoch':epoch+1,
                'state_dict':netG.state_dict()},
                'model/netG_streetview.pth' )
    torch.save({'epoch':epoch+1,
                'state_dict':netD.state_dict()},
                'model/netlocalD.pth' )

训练效果
nither=0

nither=24
nither在(0,25)之间的时候,很明显随着训练次数的增多,修复效果明显变好(当然原作者训练了250次,我电脑配置不行,跑不动啊,25次已经到尽头了)

test.py

测试集用来测试图片的修复效果,因此前面的模型定义以及图片处理方式和train.py里面的步骤一样,这里不再赘述,我们这里只展示输出的结果。

t = real_center - fake
l2 = np.mean(np.square(t))
l1 = np.mean(np.abs(t))
real_center = (real_center+1)*127.5
fake = (fake+1)*127.5

for i in range(opt.batchSize):
    p = p + psnr(real_center[i].transpose(1,2,0) , fake[i].transpose(1,2,0))

print(l2)

print(l1)

print(p/opt.batchSize)

输出L2:均方损失
输出L1:对抗损失
输出P:峰值信噪比
下面是它的测试结果
context encoder代码解读_第2张图片
与原论文相比,由于我们的训练次数太少,因此修复效果不是太好,但本文重点旨在复现整个论文代码实现的流程,重在学习!!!

你可能感兴趣的:(图像修复,深度学习,pytorch,神经网络)