基于Tensorflow和DCGAN生成动漫头像实践(一)

前言:学习tensorflow和深度学习有一段时间了,一直停留在运行别人的代码和跑mnsit和cifar10数据集上,决定从简单的动漫头像生成着手代码,经过无数的debug后终于完成大概,此间主要参考的有以下两个代码,一个是别人写的DCGAN动漫头像生成,另一个是pix2pix的tensorflow实现代码。

动漫头像生成:https://blog.csdn.net/sinat_33741547/article/details/77871170?locationNum=5&fps=1阿城

pix2pix代码:https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py


说明:本部分是数据是数据处理部分,采用的数据是别人提取好的动漫头像,共50000多张,将这些图片转化为tensorflow官方的标准数据TFrecord格式,这个格式的在tensorflow处理的时侯读取速度会快不少

数据来源

百度网盘  密码:g5qa


代码

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
'''
     读取图片数据并转化为tensorflow官方的TFrecord格式
'''
import tensorflow as tf
import os
import sys
import time


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def get_TF():
    train_dir = "./faces/"    #定义读取图片的路径
    data  = []  
    for file in os.listdir(train_dir):   #将图片的路径存储到data list中
        data.append(train_dir+file)

    
    stdi,stdo,stde=sys.stdin,sys.stdout,sys.stderr  #如果没有这部分会提示编码错误
    reload(sys)                                     #python3的reload在其他包中
    sys.setdefaultencoding('utf-8') 
    sys.stdin,sys.stdout,sys.stderr=stdi,stdo,stde  #改正reload之后print输出不了的问题
  
    sess=tf.Session()
    file_at = 0
    start_time = time.time()
    for i in range(len(data)):
        
        image_path = data[i]    #枚举每个图片的路径
        image_raw_data = tf.gfile.FastGFile(image_path,'r').read()
        img_data = tf.image.decode_jpeg(image_raw_data,channels=3)    #将读取到的图片按照jpeg的格式解压成tensor的形式            
        img_data = img_data.eval(session=sess)
        image_raw = img_data.tobytes()    #将图片的tensor变成字符串

        example = tf.train.Example(features=tf.train.Features(feature={    #构造TFrecord形式的example
                'height':_int64_feature(img_data.shape[0]),
                'width':_int64_feature(img_data.shape[1]),
                'channel':_int64_feature(img_data.shape[2]),
                'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))   #之后需要的只有'image_raw',其他可以不定义
                }))
        if i % 500 == 0:   #500个example存储为一个TFrecord文件
            file_at += 1
            filename = ("./TFrecord/data-tfrecords-%.5d" % file_at)
            if i>0:
                writer.close()
            writer = tf.python_io.TFRecordWriter(filename)
            print("%d steps,using time %f" % (i,time.time()-start_time))
            start_time =time.time()
        writer.write(example.SerializeToString()) #将examples写入TFrecord文件

 

    writer.close()
   
    
    
get_TF()
在程序实际运行的时候,一开始处理很快,但是后来生成一个TFrecord文件就越运行越慢,查了资料没发现其他人有出现这个问题,没有解决。当然,也可以直接读取原图片训练。

你可能感兴趣的:(深度学习,GAN)