参考链接:https://github.com/yenchenlin/pix2pix-tensorflow
https://blog.csdn.net/stdcoutzyx/article/details/78820728
utils.py
from __future__ import division
import math
import json
import random
import pprint
import scipy.misc
import numpy as np
from time import gmtime,strftime
# Shared pretty-printer instance for debugging output.
pp = pprint.PrettyPrinter()


def get_stddev(x, k_h, k_w):
    """Return 1 / sqrt(k_w * k_h * C), where C is the last dimension of ``x``.

    Args:
        x: any object exposing ``get_shape()`` (e.g. a TF tensor); only the last
            entry of the returned shape is used.
        k_h: kernel height.
        k_w: kernel width.
    """
    fan_in = k_w * k_h * x.get_shape()[-1]
    return 1 / math.sqrt(fan_in)
#########################################################################
# 载入图片
# Read an image file from disk.
def imread(path, is_grayscale=False):
    """Read an image file into a float64 NumPy array.

    Args:
        path: path of the image file.
        is_grayscale: if True, ask ``scipy.misc.imread`` to flatten the image to a
            single grayscale layer (``flatten=True``).

    Returns:
        The image as a ``np.float64`` array.
    """
    # np.float was only an alias of the builtin float (i.e. float64) and was
    # removed in NumPy 1.24; np.float64 is the same dtype and works everywhere.
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True).astype(np.float64)
    return scipy.misc.imread(path).astype(np.float64)
# Load an image and split it into its paired halves.
def load_image(image_path):
    """Read a side-by-side paired image and split it down the middle.

    Args:
        image_path: path of an image whose left and right halves form a pair.

    Returns:
        ``(left_half, right_half)``; the split column is ``width // 2``.
    """
    pair = imread(image_path)
    half = pair.shape[1] // 2
    left = pair[:, :half]
    right = pair[:, half:]
    return left, right
# Resize / augment the two halves of a pair.
def preprocess_A_and_B(img_A, img_B, load_size=286, fine_size=256, flip=True, is_test=False):
    """Resize an image pair and, when training, apply one random crop/flip to both.

    Args:
        img_A: first image of the pair (NumPy array).
        img_B: second image of the pair (NumPy array).
        load_size: side length both images are resized to before random cropping
            (training only). [286]
        fine_size: side length of the returned images. [256]
        flip: if True, during training the pair is mirrored left-right with
            probability 0.5.
        is_test: if True, resize straight to ``fine_size`` with no crop or flip.

    Returns:
        ``(img_A, img_B)``, each ``fine_size`` x ``fine_size`` spatially.

    Raises:
        ValueError: if ``load_size < fine_size`` during training (no valid crop).
    """
    if is_test:
        img_A = scipy.misc.imresize(img_A, [fine_size, fine_size])
        img_B = scipy.misc.imresize(img_B, [fine_size, fine_size])
        return img_A, img_B

    if load_size < fine_size:
        raise ValueError("load_size (%d) must be >= fine_size (%d)" % (load_size, fine_size))

    # Resize both images to load_size x load_size before cropping.
    img_A = scipy.misc.imresize(img_A, [load_size, load_size])
    img_B = scipy.misc.imresize(img_B, [load_size, load_size])

    # Choose one crop origin uniformly from every valid offset
    # 0..load_size-fine_size (randint's upper bound is exclusive). The former
    # ceil(uniform(1e-2, ...)) never chose offset 0, and when
    # load_size == fine_size it chose offset 1, giving a crop one pixel short.
    max_offset = load_size - fine_size
    h1 = np.random.randint(0, max_offset + 1)
    w1 = np.random.randint(0, max_offset + 1)
    img_A = img_A[h1:h1 + fine_size, w1:w1 + fine_size]
    img_B = img_B[h1:h1 + fine_size, w1:w1 + fine_size]

    # Mirror both images together so the pair stays aligned.
    if flip and np.random.random() > 0.5:
        img_A = np.fliplr(img_A)
        img_B = np.fliplr(img_B)
    return img_A, img_B
# Load one paired example ready to be fed to the model.
def load_data(image_path, flip=True, is_test=False, load_size=286, fine_size=256):
    """Load one paired image as a single normalized, channel-concatenated array.

    Args:
        image_path: path of an image holding the two halves of a pair side by side.
        flip: forwarded to ``preprocess_A_and_B`` (random horizontal flip in training).
        is_test: forwarded to ``preprocess_A_and_B`` (resize only, no crop/flip).
        load_size: resize size before random cropping, forwarded to
            ``preprocess_A_and_B``. [286]
        fine_size: final side length, forwarded to ``preprocess_A_and_B``. [256]

    Returns:
        The left and right halves mapped from [0, 255] to [-1, 1] (assuming 8-bit
        pixel values) and concatenated along axis 2, e.g.
        ``[fine_size, fine_size, 6]`` for two RGB halves.
    """
    # Split the file into its two halves.
    img_A, img_B = load_image(image_path)
    # Bring both halves to fine_size x fine_size (with augmentation in training).
    img_A, img_B = preprocess_A_and_B(img_A, img_B, load_size=load_size,
                                      fine_size=fine_size, flip=flip, is_test=is_test)
    # Normalize [0, 255] -> [-1, 1].
    img_A = img_A / 127.5 - 1.
    img_B = img_B / 127.5 - 1.
    # Stack A and B on the channel axis; both halves must be 3-D (H, W, C) here.
    img_AB = np.concatenate((img_A, img_B), axis=2)
    return img_AB
#####################################################################
# 测试
# a,b = load_image("cityscapes/train/1.jpg")
# c,d = preprocess_A_and_B(a,b)
# print(c.shape)
# a = load_data("cityscapes/train/1.jpg")
# print(a.shape)
######################################################################
def inverse_transform(images):
    """Map values from the model range [-1, 1] back to [0, 1]."""
    shifted = images + 1.
    return shifted / 2
def merge(images, size):
    """Tile a batch of images into one grid image.

    Args:
        images: array indexed as ``images[n]``; ``images.shape[1]`` and
            ``images.shape[2]`` are taken as each tile's height and width.
        size: ``(rows, cols)`` of the grid; image ``n`` goes to row ``n // cols``,
            column ``n % cols``.

    Returns:
        A float64 array of shape ``(h * rows, w * cols, 3)``.
    """
    n_rows, n_cols = size[0], size[1]
    tile_h, tile_w = images.shape[1], images.shape[2]
    canvas = np.zeros((tile_h * n_rows, tile_w * n_cols, 3))
    for pos, tile in enumerate(images):
        row, col = divmod(pos, n_cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile
    return canvas
def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write it to ``path``.

    Returns whatever ``scipy.misc.imsave`` returns.
    """
    writer = scipy.misc.imsave
    grid = merge(images, size)
    return writer(path, grid)
def save_images(images, size, image_path):
    """Rescale ``images`` from [-1, 1] to [0, 1] and save them as one grid image."""
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
ops.py
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from utils import *
# Batch normalization.
class batch_norm(object):
    """Callable wrapper around ``tf.contrib.layers.batch_norm`` with fixed settings.

    Args:
        epsilon: small constant added to the variance. [1e-5]
        momentum: decay of the moving mean/variance. [0.9]
        name: variable scope used for the normalization variables at call time.
    """

    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        # NOTE(review): no variables are created inside this scope; the BN
        # variables are created when the object is called (see __call__).
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name

    def __call__(self, x, train=True):
        """Normalize ``x``.

        Args:
            x: input tensor.
            train: if True, normalize with batch statistics and update the moving
                averages; if False, use the stored moving averages. This argument
                was previously accepted but ignored (always training mode).
        """
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,
                                            # None => moving averages are updated in place.
                                            updates_collections=None,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,
                                            scope=self.name)
def binary_cross_entropy(preds, targets, name=None):
    """Compute the mean binary cross entropy between ``preds`` and ``targets``.

    With ``p = preds`` (probabilities) and ``z = targets`` the loss is

        loss(p, z) = -mean_i(z[i] * log(p[i] + eps) + (1 - z[i]) * log(1 - p[i] + eps))

    where ``eps = 1e-12`` guards against ``log(0)``.

    Args:
        preds: A `Tensor` of type `float32` or `float64` holding probabilities.
        targets: A `Tensor` of the same type and shape as `preds`.
        name: optional name for the op scope; defaults to "bce_loss".

    Returns:
        A scalar `Tensor` with the mean loss.
    """
    eps = 1e-12
    # tf.name_scope(name, default_name, values) replaces the deprecated
    # ops.op_scope(values, name, default_name).
    with tf.name_scope(name, "bce_loss", [preds, targets]):
        preds = ops.convert_to_tensor(preds, name="preds")
        targets = ops.convert_to_tensor(targets, name="targets")
        return tf.reduce_mean(-(targets * tf.log(preds + eps) +
                                (1. - targets) * tf.log(1. - preds + eps)))
def conv_cond_concat(x, y):
    """Broadcast ``y`` over the first three dims of ``x`` and concat on axis 3.

    ``y`` is multiplied by a ones tensor of shape
    ``[x0, x1, x2, y3]`` (taken from the static shapes), so it must broadcast to
    that shape. Presumably ``x`` is NHWC and ``y`` is ``(batch, 1, 1, c_y)`` —
    confirm with callers.
    """
    x_dims = x.get_shape()
    y_dims = y.get_shape()
    ones = tf.ones([x_dims[0], x_dims[1], x_dims[2], y_dims[3]])
    tiled_y = y * ones
    return tf.concat([x, tiled_y], 3)
def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"):
    """2-D convolution with 'SAME' padding followed by a learned bias.

    Args:
        input_: input tensor; its last static dimension is the input channel count.
        output_dim: number of output channels.
        k_h, k_w: kernel height/width. [5, 5]
        d_h, d_w: strides along height/width. [2, 2]
        stddev: std-dev of the truncated-normal kernel initializer. [0.02]
        name: variable scope holding the 'w' and 'biases' variables.
    """
    with tf.variable_scope(name):
        in_channels = input_.get_shape()[-1]
        kernel = tf.get_variable('w', [k_h, k_w, in_channels, output_dim],
                                 initializer=tf.truncated_normal_initializer(stddev=stddev))
        out = tf.nn.conv2d(input_, kernel, strides=[1, d_h, d_w, 1], padding='SAME')
        bias = tf.get_variable('biases', [output_dim],
                               initializer=tf.constant_initializer(0.0))
        # Reshape to the convolution's static shape so it is preserved after bias_add.
        return tf.reshape(tf.nn.bias_add(out, bias), out.get_shape())
def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="deconv2d", with_w=False):
    """Transposed 2-D convolution followed by a learned bias.

    Args:
        input_: input tensor; its last static dimension is the input channel count.
        output_shape: full output shape passed to ``tf.nn.conv2d_transpose``;
            its last entry is the number of output channels.
        k_h, k_w: kernel height/width. [5, 5]
        d_h, d_w: strides along height/width. [2, 2]
        stddev: std-dev of the normal kernel initializer. [0.02]
        name: variable scope holding the 'w' and 'biases' variables.
        with_w: if True also return the kernel and bias variables.

    Returns:
        The output tensor, or ``(output, w, biases)`` when ``with_w`` is True.
    """
    with tf.variable_scope(name):
        out_channels = output_shape[-1]
        # Transposed-conv kernel layout: [height, width, output_channels, in_channels].
        kernel = tf.get_variable('w', [k_h, k_w, out_channels, input_.get_shape()[-1]],
                                 initializer=tf.random_normal_initializer(stddev=stddev))
        up = tf.nn.conv2d_transpose(input_, kernel, output_shape=output_shape,
                                    strides=[1, d_h, d_w, 1])
        bias = tf.get_variable('biases', [out_channels],
                               initializer=tf.constant_initializer(0.0))
        up = tf.reshape(tf.nn.bias_add(up, bias), up.get_shape())
        if not with_w:
            return up
        return up, kernel, bias
def lrelu(x, leak=0.2, name='lrelu'):
    """Leaky ReLU: elementwise ``max(x, leak * x)``. ``name`` is accepted but unused."""
    scaled = leak * x
    return tf.maximum(x, scaled)
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully connected layer: ``input_ @ Matrix + bias``.

    Args:
        input_: 2-D tensor; its second static dimension is the input width.
        output_size: number of output units.
        scope: variable scope name (defaults to "Linear").
        stddev: std-dev of the normal initializer for "Matrix". [0.02]
        bias_start: constant initial value of "bias". [0.0]
        with_w: if True also return the weight matrix and bias variables.

    Returns:
        The output tensor, or ``(output, Matrix, bias)`` when ``with_w`` is True.
    """
    in_width = input_.get_shape().as_list()[1]
    with tf.variable_scope(scope or "Linear"):
        weights = tf.get_variable("Matrix", [in_width, output_size], dtype=tf.float32,
                                  initializer=tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        out = tf.matmul(input_, weights) + bias
        return (out, weights, bias) if with_w else out
model.py
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import numpy as np
from six.moves import xrange
from ops import *
from utils import *
class pix2pix(object):
    """pix2pix conditional GAN: a U-Net generator and a convolutional discriminator.

    The TensorFlow 1.x graph is built in ``__init__`` (via ``build_model``);
    ``train`` and ``test`` run it in the session passed to the constructor.
    """

    def __init__(self, sess, image_size=256,
                 batch_size=1, sample_size=1, output_size=256,
                 gf_dim=64, df_dim=64, L1_lambda=100,
                 input_c_dim=3, output_c_dim=3, dataset_name='facades',
                 checkpoint_dir=None, sample_dir=None):
        """
        Args:
            sess: TensorFlow session
            image_size: side length of the square images fed to ``real_data``. [256]
            batch_size: The size of batch. Should be specified before training.
            sample_size: stored as ``self.sample_size``; not read elsewhere in this class.
            output_size: (optional) The resolution in pixels of the generated images. [256]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            L1_lambda: weight of the L1 term in the generator loss. [100]
            input_c_dim: (optional) Dimension of input image color. For grayscale input, set to 1. [3]
            output_c_dim: (optional) Dimension of output image color. For grayscale output, set to 1. [3]
            dataset_name: directory holding ``train/`` and ``val/`` folders of ``*.jpg`` pairs.
            checkpoint_dir: directory searched by ``load`` when training/testing starts.
            sample_dir: stored as ``self.sample_dir``; ``train`` takes the sample
                directory from ``args.sample_dir``.
        """
        self.sess = sess
        self.is_grayscale = (input_c_dim == 1)
        self.batch_size = batch_size
        self.image_size = image_size
        self.sample_size = sample_size
        self.output_size = output_size
        self.gf_dim = gf_dim
        self.df_dim = df_dim
        self.input_c_dim = input_c_dim
        self.output_c_dim = output_c_dim
        self.L1_lambda = L1_lambda
        # batch normalization : deals with poor initialization helps gradient flow
        self.d_bn1 = batch_norm(name='d_bn1')
        self.d_bn2 = batch_norm(name='d_bn2')
        self.d_bn3 = batch_norm(name='d_bn3')
        self.g_bn_e2 = batch_norm(name='g_bn_e2')
        self.g_bn_e3 = batch_norm(name='g_bn_e3')
        self.g_bn_e4 = batch_norm(name='g_bn_e4')
        self.g_bn_e5 = batch_norm(name='g_bn_e5')
        self.g_bn_e6 = batch_norm(name='g_bn_e6')
        self.g_bn_e7 = batch_norm(name='g_bn_e7')
        self.g_bn_e8 = batch_norm(name='g_bn_e8')
        self.g_bn_d1 = batch_norm(name='g_bn_d1')
        self.g_bn_d2 = batch_norm(name='g_bn_d2')
        self.g_bn_d3 = batch_norm(name='g_bn_d3')
        self.g_bn_d4 = batch_norm(name='g_bn_d4')
        self.g_bn_d5 = batch_norm(name='g_bn_d5')
        self.g_bn_d6 = batch_norm(name='g_bn_d6')
        self.g_bn_d7 = batch_norm(name='g_bn_d7')
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir
        self.build_model()

    def build_model(self):
        """Build the placeholder, generator/discriminator graphs, losses and summaries."""
        # Paired input: both images of a pair concatenated along the channel axis.
        self.real_data = tf.placeholder(tf.float32,
                                        [self.batch_size, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')
        # The first channel block is used as the target B and the second as the
        # conditioning input A. B is compared with the generator output, so it spans
        # output_c_dim channels; A is fed to the generator, so it spans input_c_dim
        # channels. With the default 3/3 split this is the same slicing as before.
        self.real_B = self.real_data[:, :, :, :self.output_c_dim]
        self.real_A = self.real_data[:, :, :, self.output_c_dim:self.output_c_dim + self.input_c_dim]
        # Generate the fake target from the conditioning input.
        self.fake_B = self.generator(self.real_A)
        # Discriminator inputs: condition + real target, condition + fake target.
        self.real_AB = tf.concat([self.real_A, self.real_B], 3)
        self.fake_AB = tf.concat([self.real_A, self.fake_B], 3)
        # Score real and fake pairs; the second call reuses the discriminator variables.
        self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False)
        self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)
        # A second generator graph sharing the generator variables, used for sampling.
        self.fake_B_sample = self.sampler(self.real_A)
        # Summaries for TensorBoard.
        self.d_sum = tf.summary.histogram("d", self.D)
        self.d__sum = tf.summary.histogram("d_", self.D_)
        self.fake_B_sum = tf.summary.image("fake_B", self.fake_B)
        # Discriminator losses: real pairs -> 1, fake pairs -> 0.
        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        # Generator loss: adversarial term + L1_lambda * L1 reconstruction term.
        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_))) \
            + self.L1_lambda * tf.reduce_mean(tf.abs(self.real_B - self.fake_B))
        # Loss summaries.
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
        self.d_loss = self.d_loss_real + self.d_loss_fake
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        t_vars = tf.trainable_variables()
        # Split variables by the variable scope that created them ("discriminator" /
        # "generator") instead of the fragile 'd_' / 'g_' substrings.
        self.d_vars = [var for var in t_vars if 'discriminator/' in var.name]
        self.g_vars = [var for var in t_vars if 'generator/' in var.name]
        self.saver = tf.train.Saver()

    def load_random_samples(self):
        """Load ``batch_size`` random pairs (with replacement) from ``<dataset>/val``.

        Returns:
            A float32 array with one row per sampled pair, as returned by ``load_data``.
        """
        pattern = os.path.join(self.dataset_name, 'val', '*.jpg')
        data = np.random.choice(glob(pattern), self.batch_size)
        sample = [load_data(sample_file) for sample_file in data]
        if self.is_grayscale:
            sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_images = np.array(sample).astype(np.float32)
        return sample_images

    def sample_model(self, sample_dir, epoch, idx):
        """Run the sampler on a random ``val`` batch, save the result and print losses."""
        sample_images = self.load_random_samples()
        # Generated images plus both losses on the sampled batch.
        samples, d_loss, g_loss = self.sess.run(
            [self.fake_B_sample, self.d_loss, self.g_loss],
            feed_dict={self.real_data: sample_images}
        )
        save_images(samples, [self.batch_size, 1],
                    './{}/train_{:02d}_{:04d}.png'.format(sample_dir, epoch, idx))
        print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))

    def train(self, args):
        """Train pix2pix.

        Reads ``args.lr``, ``args.beta1``, ``args.epoch``, ``args.train_size``,
        ``args.sample_dir`` and ``args.checkpoint_dir``.
        """
        d_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(args.lr, beta1=args.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        self.g_sum = tf.summary.merge([self.d__sum,
                                       self.fake_B_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = tf.summary.merge([self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
        if not os.path.exists('logs'):
            os.makedirs('logs')
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)
        counter = 1
        start_time = time.time()
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
        train_pattern = os.path.join(self.dataset_name, 'train', '*.jpg')
        for epoch in xrange(args.epoch):
            data = glob(train_pattern)
            print(len(data))
            # NOTE: files are consumed in glob() order; shuffling is disabled.
            batch_idxs = min(len(data), args.train_size) // self.batch_size
            for idx in xrange(0, batch_idxs):
                batch_files = data[idx * self.batch_size:(idx + 1) * self.batch_size]
                # One concatenated A/B pair per file, as produced by load_data.
                batch = [load_data(batch_file) for batch_file in batch_files]
                if self.is_grayscale:
                    batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
                else:
                    batch_images = np.array(batch).astype(np.float32)
                # Update the discriminator.
                _, summary_str = self.sess.run([d_optim, self.d_sum],
                                               feed_dict={self.real_data: batch_images})
                self.writer.add_summary(summary_str, counter)
                # Update the generator twice per D step so that d_loss does not
                # go to zero (differs from the paper).
                for _ in range(2):
                    _, summary_str = self.sess.run([g_optim, self.g_sum],
                                                   feed_dict={self.real_data: batch_images})
                    self.writer.add_summary(summary_str, counter)
                errD_fake = self.d_loss_fake.eval({self.real_data: batch_images})
                errD_real = self.d_loss_real.eval({self.real_data: batch_images})
                errG = self.g_loss.eval({self.real_data: batch_images})
                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake + errD_real, errG))
                # Save a sample image every 100 iterations.
                if np.mod(counter, 100) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)
                # Save a checkpoint every 500 iterations.
                if np.mod(counter, 500) == 2:
                    self.save(args.checkpoint_dir, counter)

    def discriminator(self, image, y=None, reuse=False):
        """Score (condition, target) pairs; returns ``(sigmoid(logits), logits)``.

        Spatial sizes below assume a 256x256 input. ``y`` is unused.
        """
        with tf.variable_scope("discriminator"):
            # image is (256 x 256 x (input_c_dim + output_c_dim))
            if reuse:
                tf.get_variable_scope().reuse_variables()
            else:
                assert tf.get_variable_scope().reuse == False
            h0 = lrelu(conv2d(image, self.df_dim, 5, 5, 2, 2, name='d_h0_conv'))
            # h0 is (128 x 128 x self.df_dim)
            h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim * 2, 5, 5, 2, 2, name='d_h1_conv')))
            # h1 is (64 x 64 x self.df_dim*2)
            h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim * 4, 5, 5, 2, 2, name='d_h2_conv')))
            # h2 is (32 x 32 x self.df_dim*4)
            h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim * 8, 5, 5, 1, 1, name='d_h3_conv')))
            # h3 is (32 x 32 x self.df_dim*8): stride 1 with SAME padding keeps the size.
            h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h3_lin')
            return tf.nn.sigmoid(h4), h4

    def generator(self, image, y=None):
        """Build the U-Net generator, creating the variables in scope 'generator'."""
        with tf.variable_scope("generator"):
            return self._unet(image)

    def sampler(self, image, y=None):
        """Build another copy of the generator graph that reuses its variables."""
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()
            return self._unet(image)

    def _unet(self, image):
        """U-Net encoder/decoder with skip connections, shared by generator and sampler.

        Must be called inside the 'generator' variable scope. Sizes in comments
        assume ``output_size == 256``. Stores each decoder deconvolution output and
        its weights/biases on ``self.d1..self.d8`` / ``self.dN_w`` / ``self.dN_b``.
        Dropout (0.5) on d1-d3 is applied in every graph built from this helper.
        """
        s = self.output_size
        s2, s4, s8, s16, s32, s64, s128 = int(s/2), int(s/4), int(s/8), int(s/16), int(s/32), int(s/64), int(s/128)
        # image is (256 x 256 x input_c_dim)
        e1 = conv2d(image, self.gf_dim, name='g_e1_conv')
        # e1 is (128 x 128 x self.gf_dim)
        e2 = self.g_bn_e2(conv2d(lrelu(e1), self.gf_dim * 2, name='g_e2_conv'))
        # e2 is (64 x 64 x self.gf_dim*2)
        e3 = self.g_bn_e3(conv2d(lrelu(e2), self.gf_dim * 4, name='g_e3_conv'))
        # e3 is (32 x 32 x self.gf_dim*4)
        e4 = self.g_bn_e4(conv2d(lrelu(e3), self.gf_dim * 8, name='g_e4_conv'))
        # e4 is (16 x 16 x self.gf_dim*8)
        e5 = self.g_bn_e5(conv2d(lrelu(e4), self.gf_dim * 8, name='g_e5_conv'))
        # e5 is (8 x 8 x self.gf_dim*8)
        e6 = self.g_bn_e6(conv2d(lrelu(e5), self.gf_dim * 8, name='g_e6_conv'))
        # e6 is (4 x 4 x self.gf_dim*8)
        e7 = self.g_bn_e7(conv2d(lrelu(e6), self.gf_dim * 8, name='g_e7_conv'))
        # e7 is (2 x 2 x self.gf_dim*8)
        e8 = self.g_bn_e8(conv2d(lrelu(e7), self.gf_dim * 8, name='g_e8_conv'))
        # e8 is (1 x 1 x self.gf_dim*8)
        self.d1, self.d1_w, self.d1_b = deconv2d(tf.nn.relu(e8),
            [self.batch_size, s128, s128, self.gf_dim * 8], name='g_d1', with_w=True)
        d1 = tf.nn.dropout(self.g_bn_d1(self.d1), 0.5)
        d1 = tf.concat([d1, e7], 3)
        # d1 is (2 x 2 x self.gf_dim*8*2)
        self.d2, self.d2_w, self.d2_b = deconv2d(tf.nn.relu(d1),
            [self.batch_size, s64, s64, self.gf_dim * 8], name='g_d2', with_w=True)
        d2 = tf.nn.dropout(self.g_bn_d2(self.d2), 0.5)
        d2 = tf.concat([d2, e6], 3)
        # d2 is (4 x 4 x self.gf_dim*8*2)
        self.d3, self.d3_w, self.d3_b = deconv2d(tf.nn.relu(d2),
            [self.batch_size, s32, s32, self.gf_dim * 8], name='g_d3', with_w=True)
        d3 = tf.nn.dropout(self.g_bn_d3(self.d3), 0.5)
        d3 = tf.concat([d3, e5], 3)
        # d3 is (8 x 8 x self.gf_dim*8*2)
        self.d4, self.d4_w, self.d4_b = deconv2d(tf.nn.relu(d3),
            [self.batch_size, s16, s16, self.gf_dim * 8], name='g_d4', with_w=True)
        d4 = self.g_bn_d4(self.d4)
        d4 = tf.concat([d4, e4], 3)
        # d4 is (16 x 16 x self.gf_dim*8*2)
        self.d5, self.d5_w, self.d5_b = deconv2d(tf.nn.relu(d4),
            [self.batch_size, s8, s8, self.gf_dim * 4], name='g_d5', with_w=True)
        d5 = self.g_bn_d5(self.d5)
        d5 = tf.concat([d5, e3], 3)
        # d5 is (32 x 32 x self.gf_dim*4*2)
        self.d6, self.d6_w, self.d6_b = deconv2d(tf.nn.relu(d5),
            [self.batch_size, s4, s4, self.gf_dim * 2], name='g_d6', with_w=True)
        d6 = self.g_bn_d6(self.d6)
        d6 = tf.concat([d6, e2], 3)
        # d6 is (64 x 64 x self.gf_dim*2*2)
        self.d7, self.d7_w, self.d7_b = deconv2d(tf.nn.relu(d6),
            [self.batch_size, s2, s2, self.gf_dim], name='g_d7', with_w=True)
        d7 = self.g_bn_d7(self.d7)
        d7 = tf.concat([d7, e1], 3)
        # d7 is (128 x 128 x self.gf_dim*1*2)
        self.d8, self.d8_w, self.d8_b = deconv2d(tf.nn.relu(d7),
            [self.batch_size, s, s, self.output_c_dim], name='g_d8', with_w=True)
        # d8 is (256 x 256 x output_c_dim)
        return tf.nn.tanh(self.d8)

    def save(self, checkpoint_dir, step):
        """Save a checkpoint under ``checkpoint_dir/<dataset>_<batch_size>_<output_size>/``."""
        model_name = "pix2pix.model"
        model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint for this model configuration.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        print(" [*] Reading checkpoint...")
        model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        return False

    def test(self, args):
        """Test pix2pix: run the sampler on every ``<dataset>/val/*.jpg`` and save outputs.

        File stems must be integers (e.g. ``12.jpg``); files are processed in
        numeric order. Results are written to ``args.test_dir``.
        """
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        # os.path.join instead of hard-coded '\\' so this also works on POSIX.
        sample_files = glob(os.path.join(self.dataset_name, 'val', '*.jpg'))
        print(sample_files)
        # Sort testing input numerically by file stem (2.jpg before 10.jpg).
        sample_files = sorted(
            sample_files,
            key=lambda path: int(os.path.splitext(os.path.basename(path))[0]))
        # load testing input
        print("Loading testing images ...")
        sample = [load_data(sample_file, is_test=True) for sample_file in sample_files]
        if self.is_grayscale:
            sample_images = np.array(sample).astype(np.float32)[:, :, :, None]
        else:
            sample_images = np.array(sample).astype(np.float32)
        # The placeholder has a fixed batch dimension, so only full batches can be
        # fed; building one array from ragged batches would fail as well.
        num_batches = len(sample_images) // self.batch_size
        leftover = len(sample_images) % self.batch_size
        if leftover:
            print(" [!] Skipping the last %d image(s): not enough for a full batch of %d"
                  % (leftover, self.batch_size))
        batches = [sample_images[i * self.batch_size:(i + 1) * self.batch_size]
                   for i in xrange(num_batches)]
        print((num_batches, self.batch_size) + tuple(sample_images.shape[1:]))
        if self.load(self.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")
        for i, sample_image in enumerate(batches):
            idx = i + 1
            print("sampling image ", idx)
            samples = self.sess.run(
                self.fake_B_sample,
                feed_dict={self.real_data: sample_image}
            )
            save_images(samples, [self.batch_size, 1],
                        './{}/test_{:04d}.png'.format(args.test_dir, idx))
main.py
import argparse
import os
import scipy.misc
import numpy as np
from model import pix2pix
import tensorflow as tf
def _str2bool(value):
    """Parse a command-line boolean ('1'/'0', 'true'/'false', 'yes'/'no', ...).

    argparse's ``type=bool`` calls ``bool(string)``, which is True for any
    non-empty string, so ``--flip False`` used to mean True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset_name', dest='dataset_name', default='cityscapes', help='name of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
# Integer default: argparse does not apply `type` to non-string defaults, so 1e8 stayed a float.
parser.add_argument('--train_size', dest='train_size', type=int, default=10 ** 8, help='# images used to train')
parser.add_argument('--load_size', dest='load_size', type=int, default=256, help='scale images to this size')
parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--niter', dest='niter', type=int, default=200, help='# of iter at starting learning rate')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--flip', dest='flip', type=_str2bool, default=True, help='if flip the images for data augmentation')
parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
parser.add_argument('--save_epoch_freq', dest='save_epoch_freq', type=int, default=1000, help='save a model every save_epoch_freq epochs (does not overwrite previously saved models)')
parser.add_argument('--save_latest_freq', dest='save_latest_freq', type=int, default=5000, help='save the latest model every latest_freq sgd iterations (overwrites the previous latest model)')
parser.add_argument('--print_freq', dest='print_freq', type=int, default=10, help='print the debug information every print_freq iterations')
parser.add_argument('--continue_train', dest='continue_train', type=_str2bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
parser.add_argument('--serial_batches', dest='serial_batches', type=_str2bool, default=False, help='if 1, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--serial_batch_iter', dest='serial_batch_iter', type=_str2bool, default=True, help='iter into serial image list')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')
parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=100.0, help='weight on L1 term in objective')
args = parser.parse_args()
def main(_):
    """Create the output directories, build the model and run the requested phase.

    Reads the module-level ``args``; ``args.phase == 'train'`` trains, any other
    value runs ``test``.
    """
    for directory in (args.checkpoint_dir, args.sample_dir, args.test_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)
    with tf.Session() as sess:
        # --ngf/--ndf/--L1_lambda were parsed but previously never reached the model.
        # NOTE(review): --input_nc/--output_nc are still not forwarded, because
        # utils.load_data always reads images in color and concatenates two 3-D halves.
        model = pix2pix(sess, image_size=args.fine_size, batch_size=args.batch_size,
                        output_size=args.fine_size, dataset_name=args.dataset_name,
                        gf_dim=args.ngf, df_dim=args.ndf, L1_lambda=args.L1_lambda,
                        checkpoint_dir=args.checkpoint_dir, sample_dir=args.sample_dir)
        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)


if __name__ == '__main__':
    # tf.app.run parses flags and then calls main(argv).
    tf.app.run()
自己训练结果