使用标签的数据集应用于生成对抗网络可以增强现有的生成模型,并形成两种优化思路。
ACGAN目标函数:
对于生成器来说有两个输入,一个是标签的分类数据c,另一个是随机数据z,得到生成数据为 ; 对于判别器分别要判断数据源是否为真实数据的概率分布 ,以及数据源对于分类标签的概率分布
ACGAN的目标函数包含两部分: 第一部分 是面向数据真实与否的代价函数 第二部分 则是数据分类准确性的代价函数。
在优化过程中希望判别器D能否使得 + 尽可能最大,而生成器G使得 - 尽可能最大; 简而言之是希望判别器能够尽可能区分真实数据和生成数据并且能有效对数据进行分类,对生成器来说希望生成数据被尽可能认为是真实数据且数据都能够被有效分类。
参考论文:Conditional Image Synthesis with Auxiliary Classifier GANs
ModelArts 是面向开发者的一站式 AI 开发平台,为机器学习与深度学习提供海量数据预处理及交互式智能标注、大规模分布式训练、自动化模型生成,及端-边-云模型按需部署能力,帮助用户快速创建和部署模型,管理全周期 AI 工作流。
下图就是ModelArts的能力图:
首先进入到AI Gallery平台,找到ACGAN实践案例
点击“Run in ModelArts”进入ModelArts的JupyterLab工作空间,如下图所示:
如果没有登录,则会跳转到登录页面,登录后会出现连接中的提示,等待约10s后可进入到空间中
由于本次体验需要使用的硬件为GPU版本,因此需要进行运行环境的切换
点击切换规格按钮,选择限时免费规格GPU,点击切换规格即可
等待切换完成:
我们可以执行:Run All Cells来快速体验案例,如下图:
点击按钮后会自动执行所有的命令,最终结果如下图所示:
体验过程很快,但是具体每个命令执行了什么内容,我们来一一分析一下:
1)下载模型和代码
import os
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/ACGAN.zip
# 解压缩
os.system('unzip ACGAN.zip -d ./')
2)模型训练:加载依赖库
root_path = './ACGAN/'
os.chdir(root_path)
import os
from main import main
from ACGAN import ACGAN
from tools import checkFolder
import tensorflow as tf
import argparse
import numpy as np
在这个步骤中我们加载了ACGAN库,用于后续训练用
3)模型训练:设置参数
def parse_args():
note = "ACGAN Frame Constructed With Tensorflow"
parser = argparse.ArgumentParser(description=note)
parser.add_argument("--epoch",type=int,default=251,help="训练轮数")
parser.add_argument("--batchSize",type=int,default=64,help="batch的大小")
parser.add_argument("--codeSize",type=int,default=62,help="输入编码向量的维度")
parser.add_argument("--checkpointDir",type=str,default="./checkpoint",help="检查点保存目录")
parser.add_argument("--resultDir",type=str,default="./result",help="训练过程中,中间生成结果的目录")
parser.add_argument("--logDir",type=str,default="./log",help="训练日志目录")
parser.add_argument("--mode",type=str,default="train",help="模式: train / infer")
parser.add_argument("--hairStyle",type=str,default="orange hair",help="你想要生成的动漫头像的头发颜色")
parser.add_argument("--eyeStyle",type=str,default="gray eyes",help="你想要生成的动漫头像的眼睛颜色")
parser.add_argument("--dataSource",type=str,default='./extra_data/images/',help="训练集路径")
args, unknown= parser.parse_known_args()
checkFolder(args.checkpointDir)
checkFolder(args.resultDir)
checkFolder(args.logDir)
assert args.epoch>=1
assert args.batchSize>=1
assert args.codeSize>=1
return args
args =parse_args()
此步骤中我们主要设置训练轮数、batch大小、头发眼睛颜色、训练日志等参数,可根据需要进行调整
4)模型训练:开始训练
with tf.Session() as sess :
myGAN = ACGAN(sess,args.epoch,args.batchSize,args.codeSize,\
args.dataSource,args.checkpointDir,args.resultDir,args.logDir,args.mode,\
64,64,3)
if myGAN is None:
print("创建GAN网络失败")
exit(0)
if args.mode=='train' :
myGAN.buildNet()
print("进入训练模式")
myGAN.train()
print("Done")
5)测试模型
首先修改参数从训练模式为推理模式
args.mode ='infer'
从标签里选择你想要生成的头像的头发和眼睛,只能从这两个列表里选择
hair_dict = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair', 'red hair', 'purple hair',
'pink hair', 'blue hair', 'black hair', 'brown hair', 'blonde hair']
eye_dict = [ 'gray eyes', 'black eyes', 'orange eyes', 'pink eyes', 'yellow eyes',
'aqua eyes', 'purple eyes', 'green eyes', 'brown eyes', 'red eyes', 'blue eyes']
# 选择了黄头发和灰眼睛
args.hairStyle = 'orange hair'
args.eyeStyle = 'gray eyes'
构造预测器:
tf.reset_default_graph()
with tf.Session() as sess :
myGAN1 = ACGAN(sess,args.epoch,args.batchSize,args.codeSize,\
args.dataSource,args.checkpointDir,args.resultDir,args.logDir,args.mode,\
64,64,3)
if myGAN1 is None:
print("创建GAN网络失败")
exit(0)
if args.mode=='infer' :
myGAN1.buildForInfer()
tag_dict = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair', 'red hair', 'purple hair', 'pink hair', 'blue hair', 'black hair',
'brown hair', 'blonde hair','gray eyes', 'black eyes', 'orange eyes', 'pink eyes', 'yellow eyes','aqua eyes', 'purple eyes', 'green eyes',
'brown eyes', 'red eyes','blue eyes']
tag = np.zeros((64,23))
feature = args.hairStyle+" AND "+ args.eyeStyle
for j in range(25):
for i in range(len(tag_dict)):
if tag_dict[i] in feature:
tag[j][i] = 1
myGAN1.infer(tag,feature)
print("Generate : "+feature)
开始生成黄色头发,灰色眼睛的动漫头像
PS:可能存在生成不了正确头像的情况
import matplotlib.pyplot as plt
from PIL import Image
feature = args.hairStyle+" AND "+ args.eyeStyle
resultPath = './samples/' + feature + '.png' #确定保存路径
img = Image.open(resultPath).convert('RGB')
plt.figure(1)
plt.imshow(img)
plt.show()
上述代码执行后即可生成动漫头像了。
通过本次体验,可以快速的了解ModelArts的一些操作逻辑,并且对ACGAN函数的应用场景有了一定的了解。
在体验过程中也会遇到一些问题,比如在加载依赖库的时候有时会出现一些告警提示,或者如果执行顺序不小心搞错导致提示文件或者目录不存在之类的错误,因此在进行模型训练的过程中,相关的命令执行一定要非常的仔细,否则容易出错。
另外这次只是一个很基础的体验,对于模型训练过程中所需的一些基础原理、基础数据、标签等还没有深入研究,期待有机会再做深入了解,也期待有志之士一起探讨学习。