【创新实训】风格迁移功能探索与实现(四) train_model.py 训练代码的编写

最难的还是train_model.py的编写

其实读论文理解起来没有那么难

但是用代码实现的话就要好多天,还不能保证没有bug...

train_model.py中最难的还是Loss function模块的编写

因为这也是论文的创新点,没有代码可以参考


train_model.py中,分成四个模块:定义参数模块、读入数据模块、loss模块、训练模块

各模块代码及注释如下:

#########################
# define the parameters #
#########################
# Loss network: model name, checkpoint file and the square image size used
# for both training images and the style image.
tf.app.flags.DEFINE_string('loss_model', 'vgg_16', '损失网络模型名 ')
tf.app.flags.DEFINE_string('loss_model_file', 'loss_model_ckpt/vgg_16.ckpt', '损失网络ckpt文件路径 ')
tf.app.flags.DEFINE_integer('image_size', 256, '图像大小')

# Checkpoint location / naming for the style-transfer model.
tf.app.flags.DEFINE_string("model_path", "transfer_model_ckpt", "风格ckpt文件路径")
tf.app.flags.DEFINE_string("model_name", "candy", "风格名")
tf.app.flags.DEFINE_string("model_file", "models.ckpt", "风格ckpt文件名")

# Content image, output directory and style image.
tf.app.flags.DEFINE_string("image_file", "srcImg/test.jpg", "输入模型的图片路径")
tf.app.flags.DEFINE_string("res_file", "resImg", "模型输出的图片保存目录")
tf.app.flags.DEFINE_string("style_image", "styleImg/candy.jpg", "风格图片的路径")

# Weights of the individual loss terms.
tf.app.flags.DEFINE_float('content_weight', 1.0, '内容损失函数权重')
tf.app.flags.DEFINE_float('style_weight', 100.0, '风格损失函数权重')
tf.app.flags.DEFINE_float('tv_weight', 0.00001, 'total variation损失函数权重')

# Training data parameters.
tf.app.flags.DEFINE_integer( 'batch_size', 128, 'batch大小')
tf.app.flags.DEFINE_integer( 'epoch', 2, 'epoch个数')

# Loss-network layers used for the content and style losses.
tf.app.flags.DEFINE_list("content_layers", "vgg_16/conv3/conv3_3", "用于计算内容损失的layers")
# NOTE: every element needs a trailing comma; adjacent string literals
# without one are silently concatenated into a single (invalid) layer name.
tf.app.flags.DEFINE_list("style_layers", ["vgg_16/conv1/conv1_2",
                                          "vgg_16/conv2/conv2_2",
                                          "vgg_16/conv3/conv3_3",
                                          "vgg_16/conv4/conv4_3"], "用于计算风格损失的layers")
# Comma-separated scope prefixes whose variables are NOT restored from the
# loss-network checkpoint (consumed by get_init_fn).
tf.app.flags.DEFINE_string("checkpoint_exclude_scopes", "vgg_16/fc", "不从ckpt中恢复权重的层")

# Optimizer learning rate.
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')

FLAGS = tf.app.flags.FLAGS

读取数据模块:

######################
# read the data #
######################
def readImage(path, height, width, preprocess_fn):
    """
    Read a single PNG/JPEG image from disk and preprocess it.

    The decoder is chosen from the file extension: paths ending in 'png'
    (case-insensitive) are decoded as PNG, everything else as JPEG. The
    image is always decoded to 3 channels.

    Key functions: tf.read_file, tf.image.decode_png / tf.image.decode_jpeg.

    :param path: path of the image file
    :param height: target height forwarded to preprocess_fn
    :param width: target width forwarded to preprocess_fn
    :param preprocess_fn: callable invoked as preprocess_fn(image, height, width)
    :return: whatever preprocess_fn returns for the decoded image
    """
    is_png = path.lower().endswith('png')
    img_data = tf.read_file(path)
    if is_png:
        image = tf.image.decode_png(img_data, channels=3)
    else:
        # The decoded tensor must be bound to `image`; otherwise the JPEG
        # path reaches the return statement with `image` undefined.
        image = tf.image.decode_jpeg(img_data, channels=3)
    return preprocess_fn(image, height, width)

def readImageBatch(batch_size, height, width, path, preprocess_fn, epochs=2, shuffle=True):
    """
    Build a queue-based input pipeline that yields preprocessed image batches.

    All regular files directly inside `path` are enqueued by name, read whole,
    decoded, preprocessed and grouped into batches of `batch_size`.

    The decoder (PNG vs. JPEG) is chosen once, from the extension of the first
    listed file, so the directory is expected to hold a single image format.

    Key functions: tf.train.string_input_producer, tf.WholeFileReader,
    reader.read, tf.train.batch.

    :param batch_size: number of images per mini-batch
    :param height: target image height forwarded to preprocess_fn
    :param width: target image width forwarded to preprocess_fn
    :param path: directory containing the training images (e.g. MS-COCO)
    :param preprocess_fn: callable invoked as preprocess_fn(image, height, width)
    :param epochs: how many times each file name is produced before the
        queue raises an OutOfRange error
    :param shuffle: whether file names are shuffled; when False they are
        processed in sorted order
    :return: the batched image tensor produced by tf.train.batch
    :raises ValueError: if `path` contains no regular files
    """
    filenames = [join(path, f) for f in listdir(path) if isfile(join(path, f))]
    if not filenames:
        # Fail early with a clear message instead of an IndexError on
        # filenames[0] below.
        raise ValueError('No image files found in directory: %s' % path)
    if not shuffle:
        filenames = sorted(filenames)

    # Image format is inferred from the first file only.
    is_png = filenames[0].lower().endswith('png')

    # Produces every file name `epochs` times, then signals OutOfRange. A
    # QueueRunner for this queue is added to the graph's QUEUE_RUNNER collection.
    filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle, num_epochs=epochs)
    # WholeFileReader outputs (file name, full file contents) per dequeued name.
    reader = tf.WholeFileReader()
    _, img_data = reader.read(filename_queue)

    # Decode to a 3-channel image tensor.
    if is_png:
        image = tf.image.decode_png(img_data, channels=3)
    else:
        image = tf.image.decode_jpeg(img_data, channels=3)

    processed_image = preprocess_fn(image, height, width)
    # Group single images into batches.
    return tf.train.batch([processed_image], batch_size, dynamic_pad=True)

定义损失函数模块:

######################
# define the loss #
######################
def gram(layer):
    """
    Compute the normalized Gram matrix of a rank-4 activation tensor.

    The two middle axes are flattened into one, the flattened tensor is
    multiplied by its own transpose per batch element, and the result is
    divided by the product of the sizes of axes 1, 2 and 3.

    Key function: tf.reshape.

    :param layer: rank-4 output tensor of an activation layer
        (presumably laid out as [batch, height, width, channels] -- TODO confirm)
    :return: tensor of shape [dim0, dim3, dim3] holding one Gram matrix per
        batch element
    """
    dims = tf.shape(layer)
    batch, axis1, axis2, channels = dims[0], dims[1], dims[2], dims[3]

    # Collapse the two middle axes: [batch, axis1 * axis2, channels].
    flattened = tf.reshape(layer, tf.stack([batch, -1, channels]))

    # flattened^T x flattened -> [batch, channels, channels], then normalize.
    normalizer = tf.to_float(axis1 * axis2 * channels)
    return tf.matmul(flattened, flattened, transpose_a=True) / normalizer

def get_styleImg_featuremaps():
    """
    Compute the style targets: Gram matrices of the style image's activations.

    Steps:
      1. Read and decode FLAGS.style_image (PNG or JPEG, chosen by extension).
      2. Preprocess it and feed it through the loss network (FLAGS.loss_model).
      3. For every layer listed in FLAGS.style_layers, take that layer's
         endpoint, compute its Gram matrix and drop the batch dimension.
      4. As a side effect, write the un-preprocessed style image to
         FLAGS.res_file as 'target_transfer_style_<model_name>.jpg'.

    Everything runs in its own temporary graph and session, so the caller's
    graph is left untouched.

    :return: a list with one evaluated Gram matrix per entry of
        FLAGS.style_layers, used to compute the style loss
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)

        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        # Read the style image. Force 3 channels, consistent with readImage /
        # readImageBatch, so grayscale or RGBA files still yield an RGB tensor.
        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes, channels=3)
        else:
            image = tf.image.decode_jpeg(img_bytes, channels=3)

        # Preprocess to FLAGS.image_size x FLAGS.image_size.
        image = image_preprocessing_fn(image, size, size)
        # Add a batch dimension: the network expects a 4-D input tensor.
        images = tf.expand_dims(image, 0)

        # endpoints_dict maps layer names (e.g. 'vgg_16/conv3/conv3_3') to the
        # activations of that layer for the style image.
        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        features = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            feature = tf.squeeze(gram(feature), [0])  # remove the batch dimension
            features.append(feature)

        with tf.Session() as sess:
            # Restore variables for the loss network.
            init_func = get_init_fn()
            init_func(sess)

            if not os.path.exists(FLAGS.res_file):
                os.makedirs(FLAGS.res_file)
            # Path of the saved (cropped / resized) style image.
            save_file = FLAGS.res_file + '/target_transfer_style_' + FLAGS.model_name + '.jpg'
            # Undo the preprocessing and write the style image to that path.
            with open(save_file, 'wb') as f:
                target_image = image_unprocessing_fn(images[0, :])
                value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
                f.write(sess.run(value))
                tf.logging.info('目标风格图片保存到: %s.' % save_file)

            # Evaluate and return the Gram matrices of the style layers.
            return sess.run(features)


def get_style_loss(endpoints_dict, style_featuremaps, style_layers):
    """
    Compute the style loss.

    For every layer in `style_layers`, the loss is the squared L2 distance
    between the Gram matrix of the generated images' activations and the
    precomputed Gram matrix of the style image, normalized by the number of
    elements of the generated activations. Per-layer losses are summed.

    :param endpoints_dict: loss-network endpoints computed on a batch whose
        first half along axis 0 is the generated images (second half is
        ignored here)
    :param style_featuremaps: precomputed style targets, one per layer and in
        the same order as `style_layers`. These are already Gram matrices
        (as returned by get_styleImg_featuremaps), not raw activations.
    :param style_layers: names of the layers used for the style loss
    :return: the summed style loss (0 if `style_layers` is empty)
    """
    style_loss = 0
    for style_gram, layer in zip(style_featuremaps, style_layers):
        # Keep only the generated-image half of the concatenated batch.
        generated_images, _ = tf.split(endpoints_dict[layer], 2, 0)
        size = tf.size(generated_images)
        # `style_gram` is already a Gram matrix, so it must not be passed
        # through gram() again; it broadcasts over the batch dimension.
        layer_style_loss = tf.nn.l2_loss(gram(generated_images) - style_gram) * 2 / tf.to_float(size)
        style_loss += layer_style_loss
    return style_loss


def get_content_loss(endpoints_dict, content_layers):
    """
    Compute the content loss from the loss-network endpoints.

    For each requested layer the endpoint is split into two equal halves
    along axis 0: the first half belongs to the generated (stylized) images,
    the second half to the original content images. The per-layer loss is
    twice the l2_loss of their difference divided by the number of elements
    in the generated half; per-layer losses are summed.

    :param endpoints_dict: endpoints produced by running the loss network on
        the concatenation of generated images and content images
    :param content_layers: names of the layers used for the content loss
    :return: the summed content loss (0 if `content_layers` is empty)
    """
    content_loss = 0
    for layer_name in content_layers:
        activations = endpoints_dict[layer_name]
        # First half: generated images; second half: content images.
        generated_part, content_part = tf.split(activations, 2, 0)
        element_count = tf.size(generated_part)
        difference = generated_part - content_part
        # Normalized squared L2 distance, as in the paper.
        content_loss += tf.nn.l2_loss(difference) * 2 / tf.to_float(element_count)
    return content_loss


def total_variation_loss(layer):
    """
    Total variation loss, used to encourage spatial smoothness of the output.

    Differences between neighbouring entries along axis 1 and along axis 2
    are taken by subtracting a shifted slice of the tensor from an unshifted
    one; each difference tensor contributes its l2_loss divided by its
    element count.

    :param layer: rank-4 tensor (presumably the generated image batch as
        [batch, height, width, channels] -- TODO confirm)
    :return: scalar tensor with the sum of the two normalized terms
    """
    dims = tf.shape(layer)
    size_axis1 = dims[1]
    size_axis2 = dims[2]

    origin = [0, 0, 0, 0]
    to_the_end = [-1, -1, -1, -1]

    # Shift-and-subtract along axis 1: entries [0, n-1) minus entries [1, n).
    leading_rows = tf.slice(layer, origin, tf.stack([-1, size_axis1 - 1, -1, -1]))
    trailing_rows = tf.slice(layer, [0, 1, 0, 0], to_the_end)
    y = leading_rows - trailing_rows

    # Same along axis 2.
    leading_cols = tf.slice(layer, origin, tf.stack([-1, -1, size_axis2 - 1, -1]))
    trailing_cols = tf.slice(layer, [0, 0, 1, 0], to_the_end)
    x = leading_cols - trailing_cols

    # Normalize each term by the number of elements it covers.
    x_term = tf.nn.l2_loss(x) / tf.to_float(tf.size(x))
    y_term = tf.nn.l2_loss(y) / tf.to_float(tf.size(y))
    return x_term + y_term

模型训练部分:

######################
# train the model #
######################
def get_init_fn():
    """
    Build the function that restores the loss-network weights.

    Model variables whose op name starts with any of the comma-separated
    prefixes in FLAGS.checkpoint_exclude_scopes are skipped; all other model
    variables are restored from FLAGS.loss_model_file. Variables missing from
    the checkpoint are ignored.

    :return: the callable created by slim.assign_from_checkpoint_fn
    """
    raw_scopes = FLAGS.checkpoint_exclude_scopes
    # Prefixes of variables that must not be restored.
    excluded_prefixes = ()
    if raw_scopes:
        excluded_prefixes = tuple(part.strip() for part in raw_scopes.split(','))

    # Keep every model variable whose name matches none of the prefixes.
    restorable = [
        variable
        for variable in slim.get_model_variables()
        if not variable.op.name.startswith(excluded_prefixes)
    ]

    return slim.assign_from_checkpoint_fn(
        FLAGS.loss_model_file,
        restorable,
        ignore_missing_vars=True)

def main():
    """
    Train the style-transfer network.

    Builds the training graph (input pipeline -> transfer network -> loss
    network), combines content, style and total-variation losses using the
    weights from FLAGS, restores the loss-network weights (and the latest
    transfer-network checkpoint, if one exists), then runs the optimizer
    until the input queue is exhausted. A checkpoint is written every 1000
    steps and once more when training finishes.
    """
    # Directory holding the checkpoints of the style being trained. The style
    # name lives in the `model_name` flag; no flag called `name` is defined.
    training_path = os.path.join(FLAGS.model_path, FLAGS.model_name)
    if not(os.path.exists(training_path)):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            # ---- Build processed_images and processed_generated ----
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = readImageBatch(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)

            # `generated` is the output of the style-transfer network
            # (presumably [batch_size, height, width, channels] -- TODO confirm).
            generated = model.base_net(processed_images, training=True)
            # Unstack along the batch axis so that each generated image can be
            # run through the same preprocessing as the content images.
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
                                   ]
            processed_generated = tf.stack(processed_generated,axis=0)
            # Generated images first, content images second: the loss
            # functions split the endpoints back into these two halves.
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of the loss network.
            for key in endpoints_dict:
                tf.logging.info(key)

            # ---- Losses ----
            # get_styleImg_featuremaps takes no arguments; it reads FLAGS itself.
            style_featuremaps = get_styleImg_featuremaps()
            content_loss = get_content_loss(endpoints_dict, FLAGS.content_layers)
            style_loss = get_style_loss(endpoints_dict, style_featuremaps, FLAGS.style_layers)
            tv_loss = total_variation_loss(generated)  # use the unprocessed image

            # Total loss: the optimization objective.
            loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # ---- Prepare training ----
            global_step = tf.Variable(0, name="global_step", trainable=False)

            # Train only variables outside the loss network's scope.
            variable_to_train = []
            for variable in tf.trainable_variables():
                if not(variable.name.startswith(FLAGS.loss_model)):
                    variable_to_train.append(variable)

            # Adam optimizer; running train_op also increments global_step.
            train_op = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss, global_step=global_step, var_list=variable_to_train)

            # Variables saved / restored by the saver: everything outside the
            # loss network's scope.
            variables_to_restore = []
            for v in tf.global_variables():
                if not(v.name.startswith(FLAGS.loss_model)):
                    variables_to_restore.append(v)

            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = get_init_fn()
            init_func(sess)

            # Resume the transfer network from the latest checkpoint, if any.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                saver.restore(sess, last_file)

            # ---- Training loop ----
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()
            try:
                while not coord.should_stop():
                    # One optimization step.
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()

                    if step % 10 == 0:
                        tf.logging.info('global_step: %d,  total_Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))

                    # Periodic checkpoint.
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, FLAGS.model_name+'.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                # The input queue has produced all epochs: save the final model.
                saver.save(sess, os.path.join(training_path, FLAGS.model_name+'.ckpt-done'))
                tf.logging.info('完成训练')
            finally:
                coord.request_stop()
            coord.join(threads)




你可能感兴趣的:(项目创新实训)