最近一部电视剧《隐秘的角落》在网上引起了众多讨论,要说这是2020年全网热度最高的电视剧也不为过。而剧中反派Boss张东升也是网友讨论的话题之一,特别是他的秃头特点,已经成为一个梗了。
突然很想知道自己秃头是什么样子,查了一下飞桨官网,果然它有图片生成的模型库。那么,我们如何使用paddlepaddle做出一个秃头生产器呢。
说到图像生成,就必须说到GAN,它是一种非监督学习的方式,通过让两个神经网络相互博弈的方法进行学习,该方法由lan Goodfellow等人在2014年提出。生成对抗网络由一个生成网络和一个判别网络组成,生成网络从潜在的空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能的分辨出来。而生成网络则尽可能的欺骗判别网络,两个网络相互对抗,不断调整参数。 生成对抗网络常用于生成以假乱真的图片。此外,该方法还被用于生成影片,三维物体模型等。
paddle模型库里用于人脸属性转换的模型主要有三种
convolution-instance norm-ReLU
组成,解码部分主要由transpose convolution-norm-ReLU
组成,判别网络主要由convolution-leaky_ReLU
组成,详细网络结构可以查看network/StarGAN_network.py
文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。
图 其他跨域模型与StarGAN模型的比较。
(a)为处理多个域,应该在每一对域都建立跨域模型。(b)StarGAN用单个generator学习多域之间的映射。该图表示连接多个域的拓扑图。
convolution-instance norm-ReLU
组成,解码部分由transpose convolution-norm-ReLU
组成,判别网络主要由convolution-leaky_ReLU
组成,详细网络结构可以查看network/AttGAN_network.py
文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。convolution-instance norm-ReLU
组成,解码网络主要由transpose convolution-norm-leaky_ReLU
组成,判别网络主要由convolution-leaky_ReLU
组成,详细网络结构可以查看network/STGAN_network.py
文件。生成网络的损失函数是由WGAN的损失函数,重构损失和分类损失组成,判别网络的损失函数由预测损失,分类损失和梯度惩罚损失组成。STGAN差不多是AttGAN的升级版,StarGAN不支持秃头属性,所以我们使用
STGAN
。
STGAN的效果图(图片属性分别为:original image, Bald, Bangs, Black Hair, Blond Hair, Brown Hair, Bushy Eyebrows, Eyeglasses, Male, Mouth Slightly Open, Mustache, No Beard, Pale Skin, Young)
本项目采用celeba数据集,关于celeba数据集的介绍,详见https://zhuanlan.zhihu.com/p/35975956
# 解压数据集
# !unzip data/data21325/imgAlignCeleba.zip -d dataset/
# !cp data/data21325/*.txt -d dataset/
# 获取模型(我已经把需要的文件放在work里,不用再获取)
# !git clone https://gitee.com/paddlepaddle/models.git -b release/1.8
# !cp -r models/PaddleCV/gan/* ./work/
# 训练,我已经花费近18个小时训练了一个,可以直接使用了,所以略过这一步吧
%cd ~/dataset
!python ../work/train.py --model_net STGAN \
--data_dir ../dataset \
--dataset . \
--crop_size 170 \
--image_size 128 \
--train_list ../dataset/attr_celeba.txt \
--gan_mode wgan \
--batch_size 32 --print_freq 1 \
--num_discriminator_time 5 \
--epoch 50 \
--dis_norm instance_norm \
--output ~/output/stgan/
# 解压训练好的模型
!unzip data/data43743/stgan.zip -d ~/
Archive: data/data43743/stgan.zip
creating: /home/aistudio/33/
creating: /home/aistudio/33/.ipynb_checkpoints/
inflating: /home/aistudio/33/net_G.pdmodel
inflating: /home/aistudio/33/net_G.pdopt
inflating: /home/aistudio/33/net_G.pdparams
在“秃头”之前,我们需要先准备要输入的图片,我把他放在my_dataset
里,并且修改dataset/test1.txt
,把图片填进去,并且根据图片的特征输入特征
%cd ~
# 输入的参数可以看看infer_bald.py开头的解释哦,主要需要注意的是n_samples、crop_size、image_size
# crop_size、image_size最好不要修改,经过我测试会影响效果
!python ./work/infer_bald.py \
--model_net STGAN \
--init_model ./33/ \
--dataset_dir my_dataset \
--test_list dataset/test1.txt \
--use_gru True \
--output ./infer_result/stgan/ \
--n_samples 1 \
--selected_attrs "Bald,Bangs,Black_Hair,Blond_Hair,Brown_Hair,Bushy_Eyebrows,Eyeglasses,Male,Mouth_Slightly_Open,Mustache,No_Beard,Pale_Skin,Young" \
--c_dim 13 \
--crop_size 178 \
--image_size 128 \
--load_height 228 \
--load_width 228 \
--crop_height 128 \
--crop_width 128 \
## 使用paddlehub
如果觉得上面的比较繁琐,infer里的代码复杂,那么有一条直接的捷径。
paddlehub里面已经有stgan的预训练模型。
# 安装paddlehub和stgan_celeba预训练模型
!pip install paddlehub==1.6.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
!hub install stgan_celeba
import paddlehub as hub
stgan = hub.Module(name="stgan_celeba")
test_img_path = ["my_dataset/img_align_celeba/000003.jpg"]
# org_info是一个只有一个元素的列表 如:["Bald,Bangs"]
# org_info要尽可能详细的说明输入图片的特征情况,否则会影响输出效果:
# 必须填写性别( "Male" 或 "Female")可选值"Bald", "Bangs",
# "Black_Hair", #"Blond_Hair", "Brown_Hair", "Bushy_Eyebrows",
# "Eyeglasses", #"Mouth_Slightly_Open", "Mustache", "No_Beard", "Pale_Skin", "Aged"
org_info = ["Male"]
# 指定要变化的特征:秃头
trans_attr = ["Bald"]
# set input dict
input_dict = {"image": test_img_path, "style": trans_attr, "info": org_info}
# execute predict and print the result
results = stgan.generate(data=input_dict)
print the result
results = stgan.generate(data=input_dict)
print(results)
https://aistudio.baidu.com/aistudio/projectdetail/620058?shared=1
最后感谢飞桨平台,让我这个初学者就能做一些有趣的试验。
还有我对stylegan挺感兴趣的,希望之后可以支持到哈
AI初学经历:
《百度架构师手把手教深度学习》:
https://aistudio.baidu.com/aistudio/education/group/info/888?activityId&shared=1
《强化学习7日打卡营》:
https://aistudio.baidu.com/aistudio/education/group/info/1335?activityId&shared=1
如果您加入官方QQ群,您将遇上大批志同道合的深度学习同学。飞桨官方QQ群:1108045677。
如果您想详细了解更多飞桨的相关内容,请参阅以下文档。
官网地址:https://www.paddlepaddle.org.cn
飞桨开源框架项目地址:
GitHub: https://github.com/PaddlePaddle/Paddle
Gitee: https://gitee.com/paddlepaddle/Paddle
飞桨生成对抗网络项目地址:
GitHub:
https://github.com/PaddlePaddle/models/tree/release/1.8/PaddleCV/gan
Gitee:
https://gitee.com/paddlepaddle/models/tree/develop/PaddleCV/gan