A complete GoogLeNet reproduction in TensorFlow

The code comes from the book 《深度学习:卷积神经网络从入门到精通》 and uses the oxflower-17 (Oxford Flowers 17) dataset.

# train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os.path
import time
import numpy as np
from six.moves import xrange
import tensorflow as tf
import data_loader
import arch
import sys
import argparse


def loss(logits, labels):  # cross-entropy loss between logits and sparse labels
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits,
                                                                   name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.summary.scalar('Cross Entropy Loss', cross_entropy_mean)  # record a scalar summary
    return cross_entropy_mean


def average_gradients(tower_grads):  # average per-tower gradients, variable by variable
    average_grads = []
    for grad_and_vars in zip(*tower_grads):  # group the (gradient, variable) pairs of each variable across towers
        grads = []
        for g, _ in grad_and_vars:
            expanded_g = tf.expand_dims(g, 0)  # add a leading tower dimension
            grads.append(expanded_g)
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads


def train(args):  # training procedure
    with tf.device('/cpu:0'):
        images, labels = data_loader.read_inputs(True, args)
        epoch_number = tf.get_variable('epoch_number', [], dtype=tf.int32,
                                       initializer=tf.constant_initializer(0), trainable=False)
        lr = tf.train.piecewise_constant(epoch_number, [19, 30, 44, 53],
                                         [0.01, 0.005, 0.001, 0.0005, 0.0001], name='LearningRate')
        wd = tf.train.piecewise_constant(epoch_number, [30], [0.0005, 0.0],
                                         name='WeightDecay')
        opt = tf.train.MomentumOptimizer(lr, 0.9)  # SGD with momentum 0.9
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in xrange(args.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('Tower_%d' % i) as scope:
                        logits = arch.get_model(images, wd, True, args)
                        top1acc = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.float32))  # top-1 accuracy
                        top5acc = tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32))  # top-5 accuracy
                        cross_entropy_mean = loss(logits, labels)
                        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                        reg_loss = tf.add_n(regularization_losses)  # sum all weight-decay terms
                        tf.summary.scalar('Regularization Loss', reg_loss)
                        total_loss = tf.add(cross_entropy_mean, reg_loss)
                        tf.summary.scalar('Total Loss', total_loss)
                        tf.summary.scalar('Top-1 Accuracy', top1acc)
                        tf.summary.scalar('Top-5 Accuracy', top5acc)
                        tf.get_variable_scope().reuse_variables()  # share variables across towers
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                        batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)
                        grads = opt.compute_gradients(total_loss)  # gradients for this tower's batch
                        tower_grads.append(grads)
    grads = average_gradients(tower_grads)
    summaries.append(tf.summary.scalar('learning_rate', lr))
    summaries.append(tf.summary.scalar('weight_decay', wd))
    apply_gradient_op = opt.apply_gradients(grads)  # apply the averaged gradients to the weights
    batchnorm_updates_op = tf.group(*batchnorm_updates)  # update the BN moving statistics
    train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=args.num_epochs)
    summary_op = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    if args.log_debug_info:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
    else:
        run_options = None
        run_metadata = None
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=args.log_device_placement))
    if args.retrain_from is not None:
        saver.restore(sess, args.retrain_from)
    else:
        sess.run(init)
    tf.train.start_queue_runners(sess=sess)  # start the input-pipeline threads
    summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)
    start_epoch = sess.run(epoch_number + 1)
    for epoch in range(start_epoch, start_epoch + args.num_epochs):
        sess.run(epoch_number.assign(epoch))
        for step in range(args.num_batches):
            start_time = time.time()
            _, loss_value, top1_accuracy, top5_accuracy = sess.run([train_op,
                                                                    cross_entropy_mean,
                                                                    top1acc, top5acc],
                                                                   options=run_options,
                                                                   run_metadata=run_metadata)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % 10 == 0:
                num_examples_per_step = args.chunked_batch_size * args.num_gpus
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = duration / args.num_gpus
                format_str = (
                    '%s: epoch %d, step %d, loss = %.2f, Top-1 = %.2f Top-5 = %.2f (%.1f examples/sec; %.3f sec/batch)')
                print(format_str % (datetime.now(), epoch, step, loss_value,
                                    top1_accuracy, top5_accuracy,
                                    examples_per_sec, sec_per_batch))
                sys.stdout.flush()  # flush so the progress line appears immediately
            if step % 100 == 0:
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, args.num_batches * epoch + step)  # write to the event file
                if args.log_debug_info:
                    summary_writer.add_run_metadata(run_metadata, 'epoch%d step%d' % (epoch, step))
        checkpoint_path = os.path.join(args.log_dir, args.snapshot_prefix)  # join the path components
        saver.save(sess, checkpoint_path, global_step=epoch)  # snapshot at the end of every epoch


def main():  # training entry point
    parser = argparse.ArgumentParser(description='Process Command-line Arguments')  # create the parser
    parser.add_argument('--load_size', nargs=2, default=[256, 256], type=int,
                        action='store',
                        help='The width and height of images for loading from disk')
    # register the remaining command-line options
    parser.add_argument('--crop_size', nargs=2, default=[224, 224], type=int,
                        action='store',
                        help='The width and height of images after random cropping')
    parser.add_argument('--batch_size', default=12, type=int, action='store',
                        help='The training batch size')
    parser.add_argument('--num_classes', default=17, type=int, action='store',
                        help='The number of classes')
    parser.add_argument('--num_channels', default=3, type=int, action='store',
                        help='The number of channels in input images')
    parser.add_argument('--num_epochs', default=55, type=int, action='store',
                        help='The number of epochs')
    parser.add_argument('--path_prefix', default='./', action='store',
                        help='the prefix address for images')
    parser.add_argument('--data_info', default='train.txt', action='store',
                        help='Name of the file containing addresses and labels of training images')
    parser.add_argument('--shuffle', default=True, type=bool, action='store',
                        help='Shuffle training data or not')
    parser.add_argument('--num_threads', default=20, type=int, action='store',
                        help='The number of threads for loading data')
    parser.add_argument('--log_dir', default=None, action='store',
                        help='Path for saving Tensorboard info and checkpoints')
    parser.add_argument('--snapshot_prefix', default='snapshot', action='store',
                        help='Prefix for checkpoint files')
    parser.add_argument('--architecture', default='googlenet', help='The DNN architecture')
    parser.add_argument('--depth', default=40, type=int, action='store',
                        help='The depth of ResNet architecture')
    parser.add_argument('--run_name',
                        default='Run' + str(time.strftime("-%d-%m-%Y_%H-%M-%S")), action='store',
                        help='Name of the experiment')
    parser.add_argument('--num_gpus', default=1, type=int, action='store',
                        help='Number of GPUs')
    parser.add_argument('--log_device_placement', default=False, type=bool,
                        help='Whether to log device placement or not')
    parser.add_argument('--delimiter', default=' ', action='store',
                        help='Delimiter of the input files')
    parser.add_argument('--retrain_from', default=None, action='store',
                        help='Continue Training from a snapshot file')
    parser.add_argument('--log_debug_info', default=False, action='store',
                        help='Logging runtime and memory usage info')
    parser.add_argument('--num_batches', default=-1, type=int, action='store',
                        help='The number of batches per epoch')
    args = parser.parse_args()  # parse the command line
    args.chunked_batch_size = int(args.batch_size / args.num_gpus)
    args.num_samples = sum(1 for line in open(args.data_info))
    if args.num_batches == -1:
        args.num_batches = int(args.num_samples / args.batch_size) + 1
    if args.log_dir is None:
        args.log_dir = args.architecture + "_" + args.run_name
    print(args)
    print("Saving everything in " + args.log_dir)
    if tf.gfile.Exists(args.log_dir):
        tf.gfile.DeleteRecursively(args.log_dir)  # recursively delete any previous contents
    tf.gfile.MakeDirs(args.log_dir)  # create the log directory
    train(args)


if __name__ == '__main__':
    main()
# data_loader.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf


def _read_label_file(file, delimiter):  # read image paths and labels from the list file
    filepaths = []
    labels = []
    with open(file, "r") as f:
        for line in f:
            tokens = line.split(delimiter)  # each line is "<path><delimiter><label>"
            filepaths.append(tokens[0])
            labels.append(int(tokens[1]))
    return filepaths, labels


def read_inputs(is_training, args):  # build the input pipeline
    filepaths, labels = _read_label_file(args.data_info, args.delimiter)
    filenames = [os.path.join(args.path_prefix, i) for i in filepaths]  # prepend the path prefix
    if is_training:  # queue that yields (filename, label) pairs
        filename_queue = tf.train.slice_input_producer([filenames, labels],
                                                        shuffle=args.shuffle, capacity=1024)
    else:
        filename_queue = tf.train.slice_input_producer([filenames, labels], shuffle=False,
                                                       capacity=1024, num_epochs=1)
    file_content = tf.read_file(filename_queue[0])  # read the image file
    reshaped_image = tf.to_float(tf.image.decode_jpeg(file_content, channels=args.num_channels))
    reshaped_image = tf.image.resize_images(reshaped_image, args.load_size)  # resize to load_size
    label = tf.cast(filename_queue[1], tf.int64)  # cast the label
    img_info = filename_queue[0]
    if is_training:
        reshaped_image = _train_preprocess(reshaped_image, args)  # training-time preprocessing
    else:
        reshaped_image = _test_preprocess(reshaped_image, args)  # test-time preprocessing
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(5000 * min_fraction_of_examples_in_queue)
    print('Filling queue with %d images before starting to train. '
          'This may take some time.' % min_queue_examples)
    batch_size = args.chunked_batch_size if is_training else args.batch_size
    # also return the image paths when predictions are to be saved
    if hasattr(args, 'save_predictions') and args.save_predictions is not None:
        images, label_batch, info = tf.train.batch([reshaped_image, label, img_info],
                                                   batch_size=batch_size,
                                                   num_threads=args.num_threads,
                                                   capacity=min_queue_examples + 3 * batch_size,
                                                   allow_smaller_final_batch=not is_training)
        return images, label_batch, info
    else:
        images, label_batch = tf.train.batch([reshaped_image, label],
                                             batch_size=batch_size,
                                             allow_smaller_final_batch=not is_training,
                                             num_threads=args.num_threads,
                                             capacity=min_queue_examples + 3 * batch_size)
        return images, label_batch


def _train_preprocess(reshaped_image, args):  # training-time preprocessing
    reshaped_image = tf.random_crop(reshaped_image,
                                    [args.crop_size[0], args.crop_size[1], args.num_channels])
    reshaped_image = tf.image.random_flip_left_right(reshaped_image)
    reshaped_image = tf.image.random_brightness(reshaped_image, max_delta=63)  # random brightness
    reshaped_image = tf.image.random_contrast(reshaped_image, lower=0.2, upper=1.8)  # random contrast
    reshaped_image = tf.image.per_image_standardization(reshaped_image)  # subtract mean, divide by stddev
    reshaped_image.set_shape([args.crop_size[0], args.crop_size[1], args.num_channels])  # fix the static shape
    return reshaped_image


def _test_preprocess(reshaped_image, args):  # test-time preprocessing
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                           args.crop_size[0], args.crop_size[1])
    float_image = tf.image.per_image_standardization(resized_image)
    float_image.set_shape([args.crop_size[0], args.crop_size[1], args.num_channels])
    return float_image
# arch.py
import googlenet


def get_model(inputs, wd, is_training, args, transferModel=False):  # dispatch on the chosen architecture
    if args.architecture == 'googlenet':
        return googlenet.inference(inputs, args.num_classes, wd,
                                   0.4 if is_training else 1.0,  # dropout keep probability
                                   is_training, transferModel)
# common.py
import tensorflow as tf
import re
from tensorflow.python.training import moving_averages
from tensorflow.python.ops import control_flow_ops
from math import sqrt

RESNET_VARIABLES = 'resnet_variables'
TOWER_NAME = 'Tower'


def _get_variable(name, shape, initializer, regularizer=None, dtype='float',
                  trainable=True):  # create a variable, pinned to the CPU
    collections = [tf.GraphKeys.GLOBAL_VARIABLES, RESNET_VARIABLES]  # collections the variable is added to
    with tf.device('/cpu:0'):
        var = tf.get_variable(name, shape=shape, initializer=initializer,
                              dtype=dtype, regularizer=regularizer,
                              collections=collections, trainable=trainable)
    return var


def batchNormalization(x, is_training=True, decay=0.9, epsilon=0.001):  # batch normalization layer
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]
    axis = list(range(len(x_shape) - 1))
    beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer)
    gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer)
    moving_mean = _get_variable('moving_mean', params_shape,
                                initializer=tf.zeros_initializer, trainable=False)
    moving_variance = _get_variable('moving_variance', params_shape,
                                    initializer=tf.ones_initializer, trainable=False)
    if is_training:
        mean, variance = tf.nn.moments(x, axis)
        update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean,
                                                                   decay)
        update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, decay)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_variance)
        return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
    else:
        return tf.nn.batch_normalization(x, moving_mean, moving_variance, beta,
                                         gamma, epsilon)


def flatten(x):  # flatten a tensor to shape (batch, features)
    shape = x.get_shape().as_list()
    dim = 1
    for i in range(1, len(shape)):
        dim *= shape[i]
    return tf.reshape(x, [-1, dim])


def treshold(x, treshold):  # keep values above the threshold, zero out the rest
    return tf.cast(x > treshold, x.dtype) * x


def fullyConnected(x, num_units_out, wd=0.0, weight_initializer=None, bias_initializer=None):  # fully connected layer
    num_units_in = x.get_shape()[1]
    stddev = 1. / tf.sqrt(tf.cast(num_units_out, tf.float32))
    if weight_initializer is None:
        weight_initializer = tf.random_uniform_initializer(minval=-stddev,
                                                           maxval=stddev, dtype=tf.float32)
    if bias_initializer is None:
        bias_initializer = tf.random_uniform_initializer(minval=-stddev, maxval=stddev,
                                                         dtype=tf.float32)
    weights = _get_variable('weights', [num_units_in, num_units_out], weight_initializer,
                            tf.contrib.layers.l2_regularizer(wd))
    biases = _get_variable('biases', [num_units_out], bias_initializer)
    return tf.nn.xw_plus_b(x, weights, biases)


def spatialConvolution(x, ksize, stride, filters_out, wd=0.0, weight_initializer=None,
                       bias_initializer=None):  # 2-D convolution with SAME padding
    filters_in = x.get_shape()[-1]
    stddev = 1. / tf.sqrt(tf.cast(filters_out, tf.float32))
    if weight_initializer is None:
        weight_initializer = tf.random_uniform_initializer(minval=-stddev,
                                                           maxval=stddev, dtype=tf.float32)
    if bias_initializer is None:
        bias_initializer = tf.random_uniform_initializer(minval=-stddev, maxval=stddev,
                                                         dtype=tf.float32)
    shape = [ksize, ksize, filters_in, filters_out]
    weights = _get_variable('weights', shape, weight_initializer, tf.contrib.layers.l2_regularizer(wd))
    conv = tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding='SAME')
    biases = _get_variable('biases', [filters_out], bias_initializer)
    return tf.nn.bias_add(conv, biases)


def maxPool(x, ksize, stride):  # max-pooling layer
    return tf.nn.max_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride,
                                                                  stride, 1], padding='SAME')


def avgPool(x, ksize, stride):  # average-pooling layer
    return tf.nn.avg_pool(x, ksize=[1, ksize, ksize, 1], strides=[1, stride,
                                                                  stride, 1], padding='SAME')
# googlenet.py
import tensorflow as tf
import common


def inception(x, conv1_size, conv3_size, conv5_size, pool1_size, wd, is_training):  # Inception module with four parallel branches
    with tf.variable_scope("conv_1"):  # 1x1 convolution branch
        conv1 = common.spatialConvolution(x, 1, 1, conv1_size, wd=wd)  # convolution
        conv1 = common.batchNormalization(conv1, is_training=is_training)  # batch normalization
        conv1 = tf.nn.relu(conv1)  # ReLU activation
    with tf.variable_scope("conv_3_1"):
        conv3 = common.spatialConvolution(x, 1, 1, conv3_size[0], wd=wd)
        conv3 = common.batchNormalization(conv3, is_training=is_training)
        conv3 = tf.nn.relu(conv3)
    with tf.variable_scope("conv_3_2"):
        conv3 = common.spatialConvolution(conv3, 3, 1, conv3_size[1], wd=wd)
        conv3 = common.batchNormalization(conv3, is_training=is_training)
        conv3 = tf.nn.relu(conv3)
    with tf.variable_scope("conv_5_1"):
        conv5 = common.spatialConvolution(x, 1, 1, conv5_size[0], wd=wd)
        conv5 = common.batchNormalization(conv5, is_training=is_training)
        conv5 = tf.nn.relu(conv5)
    with tf.variable_scope("conv_5_2"):
        conv5 = common.spatialConvolution(conv5, 5, 1, conv5_size[1], wd=wd)
        conv5 = common.batchNormalization(conv5, is_training=is_training)
        conv5 = tf.nn.relu(conv5)
    with tf.variable_scope("pool_1"):
        pool1 = common.maxPool(x, 3, 1)
        pool1 = common.spatialConvolution(pool1, 1, 1, pool1_size, wd=wd)
        pool1 = common.batchNormalization(pool1, is_training=is_training)
        pool1 = tf.nn.relu(pool1)
    return tf.concat([conv1, conv3, conv5, pool1], 3)  # concatenate the branches along the channel axis


def inference(x, num_output, wd, dropout_rate, is_training, transfer_mode=False):  # GoogLeNet architecture
    with tf.variable_scope('features'):  # trunk shared by the main and auxiliary classifiers
        with tf.variable_scope('conv1'):
            network = common.spatialConvolution(x, 7, 2, 64, wd=wd)
            network = common.batchNormalization(network, is_training=is_training)
            network = tf.nn.relu(network)
        network = common.maxPool(network, 3, 2)
        with tf.variable_scope('conv2'):
            network = common.spatialConvolution(network, 1, 1, 64, wd=wd)
            network = common.batchNormalization(network, is_training=is_training)
            network = tf.nn.relu(network)
        with tf.variable_scope('conv3'):
            network = common.spatialConvolution(network, 3, 1, 192, wd=wd)
            network = common.batchNormalization(network, is_training=is_training)
            network = tf.nn.relu(network)
        network = common.maxPool(network, 3, 2)
        with tf.variable_scope('inception3a'):
            network = inception(network, 64, [96, 128], [16, 32], 32, wd=wd,
                                is_training=is_training)
        with tf.variable_scope('inception3b'):
            network = inception(network, 128, [128, 192], [32, 96], 64, wd=wd,
                                is_training=is_training)
        network = common.maxPool(network, 3, 2)
        with tf.variable_scope('inception4a'):
            network = inception(network, 192, [96, 208], [16, 48], 64, wd=wd,
                                is_training=is_training)
        with tf.variable_scope('inception4b'):
            network = inception(network, 160, [112, 224], [24, 64], 64, wd=wd,
                                is_training=is_training)
        with tf.variable_scope('inception4c'):
            network = inception(network, 128, [128, 256], [24, 64], 64, wd=wd,
                                is_training=is_training)
        with tf.variable_scope('inception4d'):
            network = inception(network, 112, [144, 288], [32, 64], 64, wd=wd,
                                is_training=is_training)
    with tf.variable_scope('mainb'):  # main classifier branch
        with tf.variable_scope('inception4e'):
            main_branch = inception(network, 256, [160, 320], [32, 128], 128,
                                    wd=wd, is_training=is_training)
        main_branch = common.maxPool(main_branch, 3, 2)
        with tf.variable_scope('inception5a'):
            main_branch = inception(main_branch, 256, [160, 320], [32, 128],
                                    128, wd=wd, is_training=is_training)
        with tf.variable_scope('inception5b'):
            main_branch = inception(main_branch, 384, [192, 384], [48, 128],
                                    128, wd=wd, is_training=is_training)
        main_branch = common.avgPool(main_branch, 7, 1)
        main_branch = common.flatten(main_branch)
        main_branch = tf.nn.dropout(main_branch, dropout_rate)
        if not transfer_mode:
            with tf.variable_scope('output'):
                main_branch = common.fullyConnected(main_branch, num_output, wd=wd)
        else:
            with tf.variable_scope('transfer_output'):
                main_branch = common.fullyConnected(main_branch, num_output, wd=wd)
    with tf.variable_scope('auxb'):  # auxiliary classifier branch
        aux_classifier = common.avgPool(network, 5, 3)
        with tf.variable_scope('conv1'):
            aux_classifier = common.spatialConvolution(aux_classifier, 1, 1, 128,
                                                       wd=wd)
            aux_classifier = common.batchNormalization(aux_classifier,
                                                       is_training=is_training)
            aux_classifier = tf.nn.relu(aux_classifier)
        aux_classifier = common.flatten(aux_classifier)
        with tf.variable_scope('fc1'):
            aux_classifier = common.fullyConnected(aux_classifier, 1024, wd=wd)
            aux_classifier = tf.nn.dropout(aux_classifier, dropout_rate)
        if not transfer_mode:
            with tf.variable_scope('output'):
                aux_classifier = common.fullyConnected(aux_classifier, num_output,
                                                       wd=wd)
        else:
            with tf.variable_scope('transfer_output'):
                aux_classifier = common.fullyConnected(aux_classifier, num_output,
                                                       wd=wd)
    return tf.concat([main_branch, aux_classifier], 1)  # concatenate the main and auxiliary logits
# eval.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import math
import time
import os
import numpy as np
import tensorflow as tf
import argparse
import arch
import data_loader
import sys


def evaluate(args):  # evaluation procedure
    with tf.Graph().as_default() as g, tf.device('/cpu:0'):  # build the evaluation graph
        if args.save_predictions is None:  # get images and the corresponding labels
            images, labels = data_loader.read_inputs(False, args)
        else:
            images, labels, urls = data_loader.read_inputs(False, args)
        with tf.device('/gpu:0'):  # run the model on the GPU
            logits = arch.get_model(images, 0.0, False, args)  # network predictions
            top_1_op = tf.nn.in_top_k(logits, labels, 1)  # top-1 hits
            top_5_op = tf.nn.in_top_k(logits, labels, 5)  # top-5 hits
            if args.save_predictions is not None:
                top5 = tf.nn.top_k(tf.nn.softmax(logits), 5)
                top5ind = top5.indices
                top5val = top5.values
            saver = tf.train.Saver(tf.global_variables())
            summary_op = tf.summary.merge_all()
            summary_writer = tf.summary.FileWriter(args.log_dir, g)
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            ckpt = tf.train.get_checkpoint_state(args.log_dir)  # locate the trained model
            if ckpt and ckpt.model_checkpoint_path:  # load the latest checkpoint
                saver.restore(sess, ckpt.model_checkpoint_path)  # restore from the checkpoint
            else:
                return
            coord = tf.train.Coordinator()  # start the queue runners
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            true_predictions_count = 0  # count of correct top-1 predictions
            true_top5_predictions_count = 0
            step = 0
            predictions_format_str = ('%d, %s, %d, %s, %s\n')
            batch_format_str = ('Batch Number: %d, Top-1 Hit: %d, Top-5 Hit: %d, '
                                'Top-1 Accuracy: %.3f, Top-5 Accuracy: %.3f')
            if args.save_predictions is not None:
                out_file = open(args.save_predictions, 'w')
            while step < args.num_batches and not coord.should_stop():
                if args.save_predictions is None:
                    top1_predictions, top5_predictions = sess.run([top_1_op, top_5_op])
                else:
                    top1_predictions, top5_predictions, urls_values, label_values, \
                        top5guesses, top5conf = sess.run([top_1_op, top_5_op, urls, labels,
                                                          top5ind, top5val])
                    for i in range(0, urls_values.shape[0]):  # write one prediction line per image
                        out_file.write(predictions_format_str % (step * args.batch_size + i + 1,
                                                                 urls_values[i], label_values[i],
                                                                 '[' + ', '.join('%d' % item for item in top5guesses[i]) + ']',
                                                                 '[' + ', '.join('%.4f' % item for item in top5conf[i]) + ']'))
                        out_file.flush()
                true_predictions_count += np.sum(top1_predictions)
                true_top5_predictions_count += np.sum(top5_predictions)
                print(batch_format_str % (step, true_predictions_count, true_top5_predictions_count,
                                          true_predictions_count / ((step + 1.0) * args.batch_size),
                                          true_top5_predictions_count / ((step + 1.0) * args.batch_size)))
                sys.stdout.flush()
                step += 1
            if args.save_predictions is not None:
                out_file.close()
            summary = tf.Summary()
            summary.ParseFromString(sess.run(summary_op))
            coord.request_stop()
            coord.join(threads)


def main():  # evaluation entry point
    parser = argparse.ArgumentParser(description='Process Command-line Arguments')  # create the parser
    parser.add_argument('--load_size', nargs=2, default=[256, 256], type=int,
                        action='store', help='The width and height of images for loading from disk')
    parser.add_argument('--crop_size', nargs=2, default=[224, 224], type=int,
                        action='store', help='The width and height of images after random cropping')
    parser.add_argument('--batch_size', default=100, type=int, action='store',
                        help='The testing batch size')
    parser.add_argument('--num_classes', default=17, type=int, action='store',
                        help='The number of classes')
    parser.add_argument('--num_channels', default=3, type=int, action='store',
                        help='The number of channels in input images')
    parser.add_argument('--num_batches', default=-1, type=int, action='store',
                        help='The number of batches of data')
    parser.add_argument('--path_prefix', default='./', action='store',
                        help='The prefix address for images')
    parser.add_argument('--delimiter', default=' ', action='store',
                        help='Delimiter of the input files')
    parser.add_argument('--data_info', default='val.txt', action='store',
                        help='File containing the addresses and labels of testing images')
    parser.add_argument('--num_threads', default=20, type=int, action='store',
                        help='The number of threads for loading data')
    parser.add_argument('--architecture', default='resnet', help='The DNN architecture')
    parser.add_argument('--depth', default=50, type=int, help='The depth of ResNet architecture')
    parser.add_argument('--log_dir', default=None, action='store',
                        help='Path for saving Tensorboard info and checkpoints')
    parser.add_argument('--save_predictions', default=None, action='store',
                        help='Save top-5 predictions of the networks along with their confidence in the specified file')
    args = parser.parse_args()  # parse the command line
    args.num_samples = sum(1 for line in open(args.data_info))
    if args.num_batches == -1:
        if(args.num_samples % args.batch_size == 0):
            args.num_batches = int(args.num_samples / args.batch_size)
        else:
            args.num_batches = int(args.num_samples / args.batch_size) + 1
    print(args)
    evaluate(args)


if __name__ == '__main__':
    main()

Usage: put the dataset under the data folder, then run python train.py --path_prefix ./data/train from the console.
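For reference, data_loader._read_label_file expects every line of the --data_info file (train.txt by default) to hold an image path and an integer class label, separated by the --delimiter character (a space by default); the paths are later joined with --path_prefix. Below is a minimal sketch for generating such a list file, assuming the images have already been sorted into one sub-folder per class under ./data/train; the folder layout and the script name make_label_file.py are illustrative assumptions, not part of the book's code.

# make_label_file.py (illustrative helper, not from the book)
import os
import sys


def make_label_file(image_root, out_path, delimiter=' '):
    # one sub-folder per class; class indices follow the sorted folder order (assumed layout)
    class_names = sorted(d for d in os.listdir(image_root)
                         if os.path.isdir(os.path.join(image_root, d)))
    with open(out_path, 'w') as f:
        for label, name in enumerate(class_names):
            class_dir = os.path.join(image_root, name)
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(('.jpg', '.jpeg')):
                    # paths are written relative to --path_prefix
                    f.write(name + '/' + fname + delimiter + str(label) + '\n')


if __name__ == '__main__':
    # e.g. python make_label_file.py ./data/train train.txt
    make_label_file(sys.argv[1], sys.argv[2])

After training, evaluation can be launched in the same spirit, for example python eval.py --architecture googlenet --path_prefix ./data/val --data_info val.txt --log_dir <the training log_dir> (all flags as defined in eval.py's main).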
