This article implements CycleGAN with MXNet.
CycleGAN is an image-to-image translation model built from two generator networks and two discriminator networks. It learns to convert images of one domain into another using only unpaired training images, and is commonly used for style transfer.
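Concretely, the model trains two mappings G_A: A→B and G_B: B→A together with discriminators D_A (which scores domain-B images) and D_B (which scores domain-A images). On top of the two adversarial losses, a cycle-consistency term |G_B(G_A(a)) − a| + |G_A(G_B(b)) − b| forces a translated image to map back to its original, which is what makes unpaired training possible; the code below also adds an identity loss (G_A(b) ≈ b, G_B(a) ≈ a). With the weights used in the training loop further down, the generator objective is roughly
L_G = L_GAN(G_A) + L_GAN(G_B) + 10 * (L_cyc_A + L_cyc_B) + 5 * (L_idt_A + L_idt_B)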
import random, os, cv2, time
import numpy as np
import mxnet as mx
import mxnet.ndarray as nd
from mxnet import gluon, image, autograd
from mxnet.gluon.data.vision import transforms
from mxnet.base import numeric_types
from mxnet.gluon.data import DataLoader
from mxnet.gluon import nn
from mxboard import SummaryWriter
def define_G(output_nc, ngf, which_model_netG, use_dropout=False):
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(output_nc, ngf, use_dropout=use_dropout, n_blocks=9)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(output_nc, ngf, use_dropout=use_dropout, n_blocks=6)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(output_nc, 7, ngf, use_dropout=use_dropout)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(output_nc, 8, ngf, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    return netG
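ResnetGenerator and UnetGenerator are referenced by define_G but their definitions are not included in this excerpt. For orientation, here is a minimal Gluon sketch of a ResNet-style generator matching the constructor calls above; it assumes instance normalization, zero padding (the original CycleGAN paper uses reflection padding) and two down/up-sampling stages, so treat it as an illustration rather than the author's original code.

class ResnetBlock(nn.HybridBlock):
    # residual block used at the bottleneck resolution
    def __init__(self, dim, use_dropout=False, **kwargs):
        super(ResnetBlock, self).__init__(**kwargs)
        self.block = nn.HybridSequential()
        self.block.add(nn.Conv2D(dim, kernel_size=3, padding=1),
                       nn.InstanceNorm(),
                       nn.Activation('relu'))
        if use_dropout:
            self.block.add(nn.Dropout(0.5))
        self.block.add(nn.Conv2D(dim, kernel_size=3, padding=1),
                       nn.InstanceNorm())

    def hybrid_forward(self, F, x):
        return x + self.block(x)      # skip connection

class ResnetGenerator(nn.HybridBlock):
    def __init__(self, output_nc, ngf=64, use_dropout=False, n_blocks=9, **kwargs):
        super(ResnetGenerator, self).__init__(**kwargs)
        self.model = nn.HybridSequential()
        # initial 7x7 convolution
        self.model.add(nn.Conv2D(ngf, kernel_size=7, padding=3),
                       nn.InstanceNorm(), nn.Activation('relu'))
        # two downsampling steps: ngf -> 2*ngf -> 4*ngf
        for mult in (1, 2):
            self.model.add(nn.Conv2D(ngf * mult * 2, kernel_size=3, strides=2, padding=1),
                           nn.InstanceNorm(), nn.Activation('relu'))
        # n_blocks residual blocks at 4*ngf channels
        for _ in range(n_blocks):
            self.model.add(ResnetBlock(ngf * 4, use_dropout))
        # two upsampling steps back to ngf channels
        for mult in (4, 2):
            self.model.add(nn.Conv2DTranspose(ngf * mult // 2, kernel_size=3, strides=2,
                                              padding=1, output_padding=1),
                           nn.InstanceNorm(), nn.Activation('relu'))
        # final convolution to output_nc channels, tanh -> values in [-1, 1]
        self.model.add(nn.Conv2D(output_nc, kernel_size=7, padding=3),
                       nn.Activation('tanh'))

    def hybrid_forward(self, F, x):
        return self.model(x)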
def define_D(ndf, which_model_netD, n_layers_D=3, use_sigmoid=False):
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(ndf, n_layers=3, use_sigmoid=use_sigmoid)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(ndf, n_layers_D, use_sigmoid=use_sigmoid)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
    return netD
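NLayerDiscriminator is likewise not shown in this excerpt. A minimal sketch of the PatchGAN-style discriminator that define_D expects, again assuming instance normalization; the class name and signature mirror the calls above, the body is illustrative.

class NLayerDiscriminator(nn.HybridBlock):
    def __init__(self, ndf=64, n_layers=3, use_sigmoid=False, **kwargs):
        super(NLayerDiscriminator, self).__init__(**kwargs)
        self.model = nn.HybridSequential()
        self.model.add(nn.Conv2D(ndf, kernel_size=4, strides=2, padding=1),
                       nn.LeakyReLU(0.2))
        # progressively widen and downsample
        for n in range(1, n_layers):
            nf = min(ndf * 2 ** n, ndf * 8)
            self.model.add(nn.Conv2D(nf, kernel_size=4, strides=2, padding=1),
                           nn.InstanceNorm(), nn.LeakyReLU(0.2))
        nf = min(ndf * 2 ** n_layers, ndf * 8)
        self.model.add(nn.Conv2D(nf, kernel_size=4, strides=1, padding=1),
                       nn.InstanceNorm(), nn.LeakyReLU(0.2))
        # one-channel map of per-patch real/fake scores
        self.model.add(nn.Conv2D(1, kernel_size=4, strides=1, padding=1))
        if use_sigmoid:
            self.model.add(nn.Activation('sigmoid'))

    def hybrid_forward(self, F, x):
        return self.model(x)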
class DataSet(gluon.data.Dataset):
    def __init__(self, DataDir_A, DataDir_B, transform):
        self.A_paths = [os.path.join(DataDir_A, f) for f in os.listdir(DataDir_A)]
        self.B_paths = [os.path.join(DataDir_B, f) for f in os.listdir(DataDir_B)]
        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.transform = transform

    def __getitem__(self, index):
        # the two domains may have different sizes, so wrap the index per domain
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index % self.B_size]
        A_img = image.imread(A_path)
        B_img = image.imread(B_path)
        A = self.transform(A_img)
        B = self.transform(B_img)
        return A, B

    def __len__(self):
        return max(self.A_size, self.B_size)
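InitModel (not shown in this excerpt) presumably builds self.transform_fn and self.data_loader around this DataSet. A sketch of that wiring, assuming a resize to image_size and normalization to [-1, 1], which is consistent with the "* 0.5 + 0.5" de-normalization in predict() further down; the variable names are only illustrative.

transform_fn = transforms.Compose([
    transforms.Resize(256),          # image_size passed to the constructor in the usage example
    transforms.ToTensor(),           # HWC uint8 -> CHW float32 in [0, 1]
    transforms.Normalize(0.5, 0.5)   # [0, 1] -> [-1, 1], matching the generator's tanh output
])
dataset = DataSet(DataDir_A='trainA', DataDir_B='trainB', transform=transform_fn)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True,
                         last_batch='discard', num_workers=0)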
optimizer_GA = gluon.Trainer(self.netG_A.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
optimizer_GB = gluon.Trainer(self.netG_B.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
optimizer_DA = gluon.Trainer(self.netD_A.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
optimizer_DB = gluon.Trainer(self.netD_B.collect_params(), 'adam', {'learning_rate': learning_rate, 'beta1': 0.5}, kvstore='local')
cyc_loss = gluon.loss.L1Loss()
for i, (real_A, real_B) in enumerate(self.data_loader):
    real_A = gluon.utils.split_and_load(real_A, ctx_list=self.ctx, batch_axis=0)
    real_B = gluon.utils.split_and_load(real_B, ctx_list=self.ctx, batch_axis=0)
    loss_G_list = []
    loss_D_A_list = []
    loss_D_B_list = []
    fake_A_list = []
    fake_B_list = []
    losses_log.reset()
    # ---- update the two generators ----
    with autograd.record():
        for A, B in zip(real_A, real_B):
            fake_B = self.netG_A(A)        # A -> B
            rec_A = self.netG_B(fake_B)    # A -> B -> A (cycle)
            fake_A = self.netG_B(B)        # B -> A
            rec_B = self.netG_A(fake_A)    # B -> A -> B (cycle)
            idt_A = self.netG_A(B)         # identity mapping
            loss_idt_A = cyc_loss(idt_A, B) * 10.0 * 0.5
            idt_B = self.netG_B(A)
            loss_idt_B = cyc_loss(idt_B, A) * 10.0 * 0.5
            loss_G_A = self.gan_loss(self.netD_A(fake_B), True)
            loss_G_B = self.gan_loss(self.netD_B(fake_A), True)
            loss_cycle_A = cyc_loss(rec_A, A) * 10.0
            loss_cycle_B = cyc_loss(rec_B, B) * 10.0
            loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
            loss_G_list.append(loss_G)
            fake_A_list.append(fake_A)
            fake_B_list.append(fake_B)
            losses_log.add(loss_G_A=loss_G_A, loss_cycle_A=loss_cycle_A, loss_idt_A=loss_idt_A, loss_G_B=loss_G_B,
                           loss_cycle_B=loss_cycle_B, loss_idt_B=loss_idt_B, real_A=A, fake_B=fake_B, rec_A=rec_A,
                           idt_A=idt_A, real_B=B, fake_A=fake_A, rec_B=rec_B, idt_B=idt_B)
    autograd.backward(loss_G_list)
    optimizer_GA.step(self.batch_size)
    optimizer_GB.step(self.batch_size)
    # ---- update the two discriminators ----
    with autograd.record():
        for A, B, fake_A, fake_B in zip(real_A, real_B, fake_A_list, fake_B_list):
            fake_B_tmp = fake_B_pool.query(fake_B)    # sample from the history buffer
            pred_real = self.netD_A(B)
            loss_D_real = self.gan_loss(pred_real, True)
            pred_fake = self.netD_A(fake_B_tmp.detach())
            loss_D_fake = self.gan_loss(pred_fake, False)
            loss_D_A = (loss_D_real + loss_D_fake) * 0.5
            loss_D_A_list.append(loss_D_A)
            fake_A_tmp = fake_A_pool.query(fake_A)
            pred_real = self.netD_B(A)
            loss_D_real = self.gan_loss(pred_real, True)
            pred_fake = self.netD_B(fake_A_tmp.detach())
            loss_D_fake = self.gan_loss(pred_fake, False)
            loss_D_B = (loss_D_real + loss_D_fake) * 0.5
            loss_D_B_list.append(loss_D_B)
            losses_log.add(loss_D_A=loss_D_A, loss_D_B=loss_D_B)
    autograd.backward(loss_D_A_list + loss_D_B_list)
    optimizer_DA.step(self.batch_size)
    optimizer_DB.step(self.batch_size)
    # % 1 == 0: log every iteration
    if ((epoch - 1) * len(self.data_loader) + i) % 1 == 0 and self.sw is not None:
        plot_loss(losses_log, (epoch - 1) * len(self.data_loader) + i, epoch, i, self.sw)
        plot_img(losses_log, self.sw)
self.netG_A.save_parameters(os.path.join(ModelPath, 'netG_A.dat'))
self.netG_B.save_parameters(os.path.join(ModelPath, 'netG_B.dat'))
self.netD_A.save_parameters(os.path.join(ModelPath, 'netD_A.dat'))
self.netD_B.save_parameters(os.path.join(ModelPath, 'netD_B.dat'))
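The loop above relies on self.gan_loss and on the fake_A_pool / fake_B_pool objects, none of which are defined in this excerpt. In the standard CycleGAN recipe they correspond to an LSGAN-style criterion (MSE against a constant real/fake target) and a history buffer of previously generated images fed to the discriminators. Minimal sketches under that assumption; the class names, the pool size of 50, and an instantiation such as gan_loss = GANLoss(); fake_A_pool = ImagePool(50) are illustrative, not the author's code.

class GANLoss:
    """LSGAN-style criterion: MSE between the discriminator's patch map and a constant target."""
    def __init__(self):
        self._mse = gluon.loss.L2Loss()

    def __call__(self, prediction, target_is_real):
        target = nd.ones_like(prediction) if target_is_real else nd.zeros_like(prediction)
        return self._mse(prediction, target)

class ImagePool:
    """Buffer of generated images; with probability 0.5 the discriminator sees an older fake."""
    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        out = []
        for i in range(images.shape[0]):
            img = images[i:i + 1]
            if len(self.images) < self.pool_size:
                # buffer not full yet: store the new fake and return it unchanged
                self.images.append(img)
                out.append(img)
            elif random.random() > 0.5:
                # return a stored fake and replace it with the new one
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.images[idx].copy())
                self.images[idx] = img
            else:
                out.append(img)
        return nd.concat(*out, dim=0)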
def predict(self, cv_img, ATOB=True):
    img_origin = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    start_time = time.time()
    img = nd.array(img_origin)
    img = self.transform_fn(img)
    img = img.expand_dims(0).as_in_context(self.ctx)
    # inference only: no gradient recording is needed here
    if ATOB:
        output = self.netG_A(img)
    else:
        output = self.netG_B(img)
    predict = mx.nd.squeeze(output)
    # undo the [-1, 1] normalization and convert CHW float -> HWC uint8
    predict = ((predict.transpose([1, 2, 0]).asnumpy() * 0.5 + 0.5) * 255).clip(0, 255).astype('uint8')
    res_image = cv2.cvtColor(predict, cv2.COLOR_RGB2BGR)   # back to BGR for OpenCV display
    result_value = {
        "image_result": res_image,
        "time": (time.time() - start_time) * 1000
    }
    return result_value
Calling the code is straightforward:
if __name__ == '__main__':
    ctu = Ctu_CycleGan(USEGPU='0', image_size=256)
    ctu.InitModel(DataDir_A='D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite/trainA',
                  DataDir_B='D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite/trainB',
                  channels=3, batch_size=1, num_workers=0, channels_rate=0.5)
    ctu.train(TrainNum=300, learning_rate=0.0001, lr_decay_epoch='50,100,150,200', lr_decay=0.9, ModelPath='./Model', logDir='./logs')

    ctu = Ctu_CycleGan(USEGPU='0', image_size=256)
    ctu.LoadModel(ModelPath=['./Model/netG_A.dat', './Model/netG_B.dat', './Model/netD_A.dat', './Model/netD_B.dat'])
    cv2.namedWindow("origin", 0)
    cv2.resizeWindow("origin", 640, 480)
    cv2.namedWindow("result", 0)
    cv2.resizeWindow("result", 640, 480)
    for root, dirs, files in os.walk(r'D:/Ctu/Ctu_Project_DL/DataSet/DataSet_GAN/summer2winter_yosemite/testA'):
        for f in files:
            img_cv = cv2.imread(os.path.join(root, f))
            if img_cv is None:
                continue
            res = ctu.predict(img_cv, ATOB=True)
            print("inference time: " + str(res['time']) + ' ms')
            cv2.imshow("origin", img_cv)
            cv2.imshow("result", res['image_result'])
            cv2.waitKey()