Implementing GAN-CycleGAN with MXNet (with Partial Source Code)

Table of Contents

  • Preface
  • I. What is CycleGAN
  • II. Code Implementation
    • 1. Importing Libraries
    • 2. Network Construction
    • 3. Data Loader
    • 4. Model Training
      • 1. Optimizer Setup
      • 2. Loss Function Definition
      • 3. Training Loop
      • 4. Model Saving
    • 5. Model Prediction
  • III. Main Entry Point
  • IV. Training Results


Preface

This article implements CycleGAN with MXNet.


I. What is CycleGAN

CycleGAN is an image-to-image translation model composed of two generator networks and two discriminator networks. It learns to translate images of one class into another class from unpaired training data, which makes it well suited for style transfer (the full training objective is given after the list below).

  • A Generative Adversarial Network (GAN) is a deep learning model and one of the most promising approaches of recent years for unsupervised learning over complex distributions. The framework contains (at least) two modules, a generative model and a discriminative model, which learn by playing a game against each other and thereby produce remarkably good outputs. The original GAN theory does not require G and D to be neural networks, only functions capable of fitting the respective generation and discrimination mappings; in practice, however, deep neural networks are almost always used for both. A successful GAN application also needs a sound training procedure, since the flexibility of neural network models can otherwise lead to unsatisfactory outputs.
  • Machine learning models can be broadly divided into two classes: generative models and discriminative models. A discriminative model takes an input and uses some model to predict an output from it; a generative model, given some latent information, randomly produces observed data. A simple example:
  • Generative model: given a collection of cat pictures, generate a new cat image (one not in the dataset)
  • Discriminative model: given a picture, decide whether the animal in it is a cat or a dog
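
Concretely, write the two generators as G : A → B and F : B → A, each with a discriminator on its target domain. The full training objective is the standard one from the CycleGAN paper, combining two adversarial losses with a cycle-consistency loss (the weight λ = 10 matches the × 10.0 factors in the training code below):

\mathcal{L}(G, F, D_A, D_B) = \mathcal{L}_{\mathrm{GAN}}(G, D_B) + \mathcal{L}_{\mathrm{GAN}}(F, D_A) + \lambda\,\mathcal{L}_{\mathrm{cyc}}(G, F)

\mathcal{L}_{\mathrm{cyc}}(G, F) = \mathbb{E}_{a \sim A}\big[\lVert F(G(a)) - a \rVert_1\big] + \mathbb{E}_{b \sim B}\big[\lVert G(F(b)) - b \rVert_1\big]

The implementation below also adds the optional identity term \mathbb{E}_{b \sim B}[\lVert G(b) - b \rVert_1] + \mathbb{E}_{a \sim A}[\lVert F(a) - a \rVert_1], weighted by λ/2, which is what the 10.0 × 0.5 factors in the training loop compute.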

II. Code Implementation

1. Importing Libraries

import random, os, cv2, time
import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon, image, autograd
from mxnet.gluon.data.vision import transforms
from mxnet.base import numeric_types
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from mxboard import SummaryWriter  # TensorBoard-compatible logging for MXNet
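
The training loop later passes self.ctx to split_and_load as a list of devices, but the article does not show how it is built. A minimal sketch, assuming the USEGPU='0' argument seen in the main entry point holds a comma-separated list of GPU ids (this helper is hypothetical, not the author's code):

def build_ctx(USEGPU):
    # turn the USEGPU string into the list of contexts split_and_load expects
    if USEGPU is not None and USEGPU.strip() != '':
        return [mx.gpu(int(i)) for i in USEGPU.split(',')]
    return [mx.cpu()]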

2. Network Construction

def define_G(output_nc, ngf, which_model_netG, use_dropout=False):
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(output_nc, ngf, use_dropout=use_dropout, n_blocks=9)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(output_nc, ngf, use_dropout=use_dropout, n_blocks=6)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(output_nc, 7, ngf, use_dropout=use_dropout)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(output_nc, 8, ngf, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    return netG

def define_D(ndf, which_model_netD, n_layers_D=3, use_sigmoid=False):
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(ndf, n_layers=3, use_sigmoid=use_sigmoid)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(ndf, n_layers_D, use_sigmoid=use_sigmoid)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
    return netD
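
ResnetGenerator, UnetGenerator, and NLayerDiscriminator are not included in this partial listing. As an illustration, here is a minimal sketch of the residual block that a resnet_9blocks/resnet_6blocks generator stacks in its middle section; the layer choices (reflection padding, InstanceNorm) follow the original CycleGAN architecture and are an assumption, not the author's exact code:

class ResnetBlock(nn.HybridBlock):
    # one residual block: x + conv_block(x), stacked n_blocks times
    # inside ResnetGenerator
    def __init__(self, dim, use_dropout=False, **kwargs):
        super(ResnetBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.conv_block = nn.HybridSequential()
            self.conv_block.add(nn.ReflectionPad2D(1),
                                nn.Conv2D(dim, kernel_size=3),
                                nn.InstanceNorm(),
                                nn.Activation('relu'))
            if use_dropout:
                self.conv_block.add(nn.Dropout(0.5))
            self.conv_block.add(nn.ReflectionPad2D(1),
                                nn.Conv2D(dim, kernel_size=3),
                                nn.InstanceNorm())

    def hybrid_forward(self, F, x):
        # the skip connection keeps gradients flowing through deep generators
        return x + self.conv_block(x)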

3. Data Loader

class DataSet(gluon.data.Dataset):
    """Unpaired two-domain dataset: domain A and domain B images are loaded
    independently, since CycleGAN does not require paired samples."""
    def __init__(self, DataDir_A, DataDir_B, transform):
        self.A_paths = sorted(os.path.join(DataDir_A, f) for f in os.listdir(DataDir_A))
        self.B_paths = sorted(os.path.join(DataDir_B, f) for f in os.listdir(DataDir_B))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.transform = transform

    def __getitem__(self, index):
        # modulo wrap-around lets the smaller domain repeat until the
        # larger one is exhausted (see __len__)
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index % self.B_size]
        A = self.transform(image.imread(A_path))
        B = self.transform(image.imread(B_path))
        return A, B

    def __len__(self):
        return max(self.A_size, self.B_size)
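
The transform argument is not shown in the article. A minimal sketch of how this dataset might be wired to a DataLoader, assuming the usual CycleGAN preprocessing (resize, then normalize to [-1, 1]); the author's actual transform may differ:

transform_fn = transforms.Compose([
    transforms.Resize(256),           # match the image_size used in training
    transforms.ToTensor(),            # HWC uint8 -> CHW float32 in [0, 1]
    transforms.Normalize(0.5, 0.5),   # -> [-1, 1], matching the * 0.5 + 0.5 in predict()
])
dataset = DataSet('path/to/trainA', 'path/to/trainB', transform_fn)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)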

4. Model Training

1. Optimizer Setup

# one Adam trainer per sub-network; beta1 = 0.5 is the usual setting for GAN training
optimizer_GA = gluon.Trainer(self.netG_A.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
optimizer_GB = gluon.Trainer(self.netG_B.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
optimizer_DA = gluon.Trainer(self.netD_A.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
optimizer_DB = gluon.Trainer(self.netD_B.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')

2. Loss Function Definition

cyc_loss = gluon.loss.L1Loss()  # L1 criterion shared by the cycle-consistency and identity terms
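
The adversarial criterion self.gan_loss used below is not shown either. A minimal sketch assuming the least-squares GAN loss that CycleGAN conventionally uses (comparing discriminator outputs against all-ones or all-zeros targets); the author's actual implementation may differ:

class GANLoss:
    # hypothetical stand-in for self.gan_loss; LSGAN-style MSE criterion
    def __init__(self):
        self.loss = gluon.loss.L2Loss()

    def __call__(self, prediction, target_is_real):
        # build a target tensor of the same shape as the discriminator output
        if target_is_real:
            target = nd.ones_like(prediction)
        else:
            target = nd.zeros_like(prediction)
        return self.loss(prediction, target)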

3. Training Loop

for i, (real_A, real_B) in enumerate(self.data_loader):
    # scatter each batch across the devices in self.ctx
    real_A = gluon.utils.split_and_load(real_A, ctx_list=self.ctx, batch_axis=0)
    real_B = gluon.utils.split_and_load(real_B, ctx_list=self.ctx, batch_axis=0)
    loss_G_list = []
    loss_D_A_list = []
    loss_D_B_list = []
    fake_A_list = []
    fake_B_list = []
    losses_log.reset()
    # ---- generator step: update G_A (A -> B) and G_B (B -> A) ----
    with autograd.record():
        for A, B in zip(real_A, real_B):
            fake_B = self.netG_A(A)      # translate A -> B
            rec_A = self.netG_B(fake_B)  # cycle back B -> A
            fake_A = self.netG_B(B)      # translate B -> A
            rec_B = self.netG_A(fake_A)  # cycle back A -> B

            # identity losses: feeding a generator an image already in its
            # target domain should leave it (nearly) unchanged
            idt_A = self.netG_A(B)
            loss_idt_A = cyc_loss(idt_A, B) * 10.0 * 0.5
            idt_B = self.netG_B(A)
            loss_idt_B = cyc_loss(idt_B, A) * 10.0 * 0.5

            # adversarial losses: each generator tries to fool its discriminator
            loss_G_A = self.gan_loss(self.netD_A(fake_B), True)
            loss_G_B = self.gan_loss(self.netD_B(fake_A), True)
            # cycle-consistency losses, weighted by lambda = 10
            loss_cycle_A = cyc_loss(rec_A, A) * 10.0
            loss_cycle_B = cyc_loss(rec_B, B) * 10.0
            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B

            loss_G_list.append(loss_G)
            fake_A_list.append(fake_A)
            fake_B_list.append(fake_B)
            losses_log.add(loss_G_A=loss_G_A, loss_cycle_A=loss_cycle_A, loss_idt_A=loss_idt_A, loss_G_B=loss_G_B,
                           loss_cycle_B=loss_cycle_B, loss_idt_B=loss_idt_B, real_A=A, fake_B=fake_B, rec_A=rec_A,
                           idt_A=idt_A, real_B=B, fake_A=fake_A, rec_B=rec_B, idt_B=idt_B)
        autograd.backward(loss_G_list)
    optimizer_GA.step(self.batch_size)
    optimizer_GB.step(self.batch_size)
    # ---- discriminator step: update D_A and D_B on real vs. pooled fakes ----
    with autograd.record():
        for A, B, fake_A, fake_B in zip(real_A, real_B, fake_A_list, fake_B_list):
            # D_A scores real B images against generated ones drawn from the image pool
            fake_B_tmp = fake_B_pool.query(fake_B)
            pred_real = self.netD_A(B)
            loss_D_real = self.gan_loss(pred_real, True)
            pred_fake = self.netD_A(fake_B_tmp.detach())  # detach: no gradient into the generators
            loss_D_fake = self.gan_loss(pred_fake, False)
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A)

            # D_B scores real A images against generated ones
            fake_A_tmp = fake_A_pool.query(fake_A)
            pred_real = self.netD_B(A)
            loss_D_real = self.gan_loss(pred_real, True)
            pred_fake = self.netD_B(fake_A_tmp.detach())
            loss_D_fake = self.gan_loss(pred_fake, False)
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B)
            losses_log.add(loss_D_A=loss_D_A, loss_D_B=loss_D_B)
        autograd.backward(loss_D_A_list + loss_D_B_list)
    optimizer_DA.step(self.batch_size)
    optimizer_DB.step(self.batch_size)
    # log scalars and sample images every iteration (the % 1 check is trivially
    # always true, but keeps the logging interval easy to change)
    if ((epoch-1) * len(self.data_loader) + i) % 1 == 0 and self.sw is not None:
        plot_loss(losses_log, (epoch-1) * len(self.data_loader) + i, epoch, i, self.sw)
        plot_img(losses_log, self.sw)
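
Two helpers used in this loop are not part of the listing. fake_A_pool and fake_B_pool implement the image buffer from the CycleGAN paper: the discriminators are trained against a history of previously generated images rather than only the newest batch, which reduces oscillation. A minimal sketch, assuming the paper's pool size of 50 (the author's actual class may differ):

class ImagePool:
    # hypothetical stand-in for fake_A_pool / fake_B_pool
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:
            return image                 # pool disabled: pass through
        if len(self.images) < self.pool_size:
            self.images.append(image)    # fill the pool first
            return image
        if random.random() > 0.5:
            # swap the new image in and train D on an older one
            idx = random.randint(0, self.pool_size - 1)
            old = self.images[idx].copy()
            self.images[idx] = image
            return old
        return image

losses_log is a small accumulator exposing reset() and add(**kwargs); something along these lines would satisfy the interface used above (again an assumption, since the class is not shown):

class LossBuffer:
    def __init__(self):
        self.buffers = {}

    def reset(self):
        self.buffers = {}

    def add(self, **kwargs):
        # append each named value (loss ndarray or image batch) to its buffer
        for name, value in kwargs.items():
            self.buffers.setdefault(name, []).append(value)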

4. Model Saving

self.netG_A.save_parameters(os.path.join(ModelPath, 'netG_A.dat'))
self.netG_B.save_parameters(os.path.join(ModelPath, 'netG_B.dat'))
self.netD_A.save_parameters(os.path.join(ModelPath, 'netD_A.dat'))
self.netD_B.save_parameters(os.path.join(ModelPath, 'netD_B.dat'))
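
For completeness: the weights saved above can be restored with load_parameters, the Gluon counterpart of save_parameters. A sketch of what LoadModel presumably does internally:

self.netG_A.load_parameters(os.path.join(ModelPath, 'netG_A.dat'), ctx=self.ctx)
self.netG_B.load_parameters(os.path.join(ModelPath, 'netG_B.dat'), ctx=self.ctx)
self.netD_A.load_parameters(os.path.join(ModelPath, 'netD_A.dat'), ctx=self.ctx)
self.netD_B.load_parameters(os.path.join(ModelPath, 'netD_B.dat'), ctx=self.ctx)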

5. Model Prediction

def predict(self, cv_img, ATOB=True):
    # OpenCV delivers BGR; the networks were trained on RGB
    img_origin = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    start_time = time.time()
    img = nd.array(img_origin)
    img = self.transform_fn(img)
    img = img.expand_dims(0).as_in_context(self.ctx)
    # pure inference: no autograd.record() needed here
    if ATOB:
        output = self.netG_A(img)  # domain A -> domain B
    else:
        output = self.netG_B(img)  # domain B -> domain A
    predict = mx.nd.squeeze(output)
    # undo the [-1, 1] normalization and convert CHW float to HWC uint8
    predict = ((predict.transpose([1, 2, 0]).asnumpy() * 0.5 + 0.5) * 255).clip(0, 255).astype('uint8')
    res_image = cv2.cvtColor(predict, cv2.COLOR_RGB2BGR)
    result_value = {
        "image_result": res_image,
        "time": (time.time() - start_time) * 1000  # milliseconds
    }
    return result_value

III. Main Entry Point

Invoking my code is fairly straightforward:

if __name__ == '__main__':
    # ---- training ----
    ctu = Ctu_CycleGan(USEGPU='0', image_size=256)
    ctu.InitModel(DataDir_A='D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite/trainA',
                  DataDir_B='D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite/trainB',
                  channels=3, batch_size=1, num_workers=0, channels_rate=0.5)
    ctu.train(TrainNum=300, learning_rate=0.0001, lr_decay_epoch='50,100,150,200', lr_decay=0.9, ModelPath='./Model', logDir='./logs')

    # ---- prediction on the test set ----
    ctu = Ctu_CycleGan(USEGPU='0', image_size=256)
    ctu.LoadModel(ModelPath=['./Model/netG_A.dat', './Model/netG_B.dat', './Model/netD_A.dat', './Model/netD_B.dat'])
    cv2.namedWindow("origin", 0)
    cv2.resizeWindow("origin", 640, 480)
    cv2.namedWindow("result", 0)
    cv2.resizeWindow("result", 640, 480)
    for root, dirs, files in os.walk(r'D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite/testA'):
        for f in files:
            img_cv = cv2.imread(os.path.join(root, f))
            if img_cv is None:
                continue
            res = ctu.predict(img_cv, ATOB=True)
            print("Elapsed: " + str(res['time']) + ' ms')
            cv2.imshow("origin", img_cv)
            cv2.imshow("result", res['image_result'])
            cv2.waitKey()

IV. Training Results
