paper:https://sites.google.com/site/xipengcshomepage/research/ijcai18
code:https://github.com/bluer555/CR-GAN
在开始前我先提醒各位进行的小伙伴,本篇文章的训练所需要的数据集相当大(10G有多吧),所以你们需要开足才能去下载这个庞大的数据集,不然就不要训练了。
CR-GAN是用GAN(对抗网络)来进行人脸多角度的图片生成的文章,相对于先前的BiGAN、DR-GAN、TP-GAN,CR-GAN在网络上做了一个改进通过采用双支网络已经共享网络参数的形式进行训练,本篇博客仅对该文章的训练做一个详细的说明以及步骤的说明。
训练环境的搭建:
1. Python 2.7
2. Pytorch 0.3.1
请根据自己电脑cuda的版本进行下载,不然就会出问题,本人系统为linux,cuda为8.0,python 为2.7,所以一下给出的安装指令以及是该环境下的0.3.1版本,千万别弄错了,不然会报错误或警告。
:~$ pip install http://download.pytorch.org/whl/cu80/torch-0.3.1-cp27-cp27mu-linux_x86_64.whl
相关配包教程:https://ptorch.com/news/145.html
官网安装:https://pytorch.org/previous-versions/
训练的步骤:
1. 下载预训练模型;
Google驱动下载:https://drive.google.com/open?id=1J3VffWKe8akdiNM2hy7NI3lY4xM_xL-c
百度网盘下载:
1.https://pan.baidu.com/s/1Bc_Ipkz22Q28McfjH7thOQ 密码:ac63
2.https://pan.baidu.com/s/1DvCWRbgOJQpjaPV8J4lZIA 密码:avwe
3.https://pan.baidu.com/s/1391QFBo4wL7xZhiu4fWYyQ 密码:13zb
2. 下载训练的数据库;
(由于数据库相对较大,我就不上传了,实在没办法,请自行下载吧)
将下载好的数据存放入你的工程项目CR-GAN文件夹的data路径下,没有就自行创建一个。
数据库1: https://drive.google.com/open?id=1QxNCh6vfNSZkod1Rg_zHLI1FM8WyXix4
300w-LP 数据库: http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm
数据库2: https://drive.google.com/open?id=1DD6AO9Y5rAgiiW7IJY2kBxI_bCcfhYo4
300w-LP(作用于dataload.py中的txt文件): https://drive.google.com/open?id=1TIfcpn4N3rgGlzWl0lXNZKhy7XWVWOoA
3. 从git中下载源码;
:~$ git clone https://github.com/bluer555/CR-GAN
4. 在源码中修改相关读取文件的路径;
(1)在train.py文件
# 训练所需要图片的路径的列表
parser.add_argument("-d", "--data_list", type=str, default="./list_test.txt")
# 训练模型保存的路径
parser.add_argument('--outf', default='./evaluate', help='folder to output images and model checkpoints')
# 预训练模型的路径
parser.add_argument('--modelf', default='./evaluate_model', help='folder to input images and model checkpoints')
(2)在data_loader.py文件
def get_multiPIE_img(img_path):
.....
img2_path = '/porject-path/data/multi_PIE_crop_128/' + ID + '/' + ID + '_01_' + status + '_' + view + '_' + bright + '_crop_128.png'
.....
这个地方需要修改你工程项目的路径。
5. 在项目的当前目录下输入:
:~$ cd CRGAN
:~$ python train.py
6. 当你开始训练时会遇到一些问题,请不要紧张,这个问题不是大问题:
问题1:...data_parallel.py:24: UserWarning: .......... warnings.warn(imbalance_warn.format(...))
解析:这个问题是你安装的pytorch版本不是0.3.1所造成的
问题2:...model.py:115: UserWarning: Implicit dimensiion choice ..........Change the call to include dim=x.... v = self.softmax(v)....
解析:这个问题是由于作者先前写项目是给予pytorch较低的版本写的,所以会出现这个警告,不用管。
训练的全过程到此结束,如果要测试请见下一篇博客。
以上是本作者为了解决问题所提出的建议,不见得都能适用,但是基本能够解决问题,如果有什么不正确的地方,请在下方评论区留言,本文章仅代表作者本人意见!