tfrecord数据集训练及验证

tfrecord数据集训练及验证

导入必要的库

#!/usr/bin/env python
# -*- coding:utf-8 -*-

from mk_tfrecord import *
#from model import *
from inception_v3 import *
import numpy as np
import os
import cv2

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

定义训练函数

def training():
    N_CLASSES = 2              # 分类数目
    IMG_W = 299                # 统一图片大小,宽度
    IMG_H = 299                # 统一图片大小,高度
    BATCH_SIZE = 64            # 批次大小
    MAX_STEP = 50000           # 迭代次数
    LEARNING_RATE = 0.0001     # 学习率
    min_after_dequeue = 1000

    tfrecord_filename = '/home/xieqi/project/cat_dog/train.tfrecords'   # 训练数据集
    logs_dir = '/home/xieqi/project/cat_dog/log_v3'     # 检查点保存路径

    # 输入--要生成的字符串的一维字符串张量,shuffle默认为True,输出--字符串队列
    # 将字符串(例如文件名)输出到输入管道的队列,不限制num_epoch。
    filename_queue = tf.train.string_input_producer([tfrecord_filename], num_epochs=150)
    train_image, train_label = read_and_decode(filename_queue, image_W=IMG_W, image_H=IMG_H,
                batch_size=BATCH_SIZE,min_after_dequeue=min_after_dequeue) # 返回的为tensor

    train_labels = tf.one_hot(train_label, N_CLASSES)

    train_logits,_ = inception_v3(train_image,num_classes=N_CLASSES)
    train_loss = loss(train_logits, train_labels) # 损失函数
    train_acc = accuracy(train_logits, train_labels) # 模型精确度
    my_global_step = tf.Variable(0, name='global_step', trainable=False) # 全局步长
    train_op = optimize(train_loss, LEARNING_RATE, my_global_step) #训练模型

    summary_op = tf.summary.merge_all() # 收集模型统计信息
    # 初始化全局变量和局部变量
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    # 限制GPU使用率
    # sess_config = tf.ConfigProto()
    # sess_config.gpu_options.per_process_gpu_memory_fraction = 0.70
    # sess = tf.Session(config=sess_config)

    sess = tf.Session()
    # FileWriter类提供了一个机制来创建指定目录的事件文件,并添加摘要和事件给它(异步更新,不影响训练速度)
    train_writer = tf.summary.FileWriter(logs_dir, sess.graph)
    # 将Save类添加OPS保存和恢复变量和检查点。对模型定期做checkpoint,通常用于模型恢复
    saver = tf.train.Saver()

    sess.run(init_op)
    # 线程协调员, 实现一种简单的机制来协调一组线程的终止
    coord = tf.train.Coordinator()
    # 启动图中收集的所有队列, 开始填充队列
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break

            # 获取一个批次的数据及标签
            image_batch, label_batch = sess.run([train_image, train_label])
            sess.run(train_op)

            # 每迭代100次计算一次loss和准确率
            if step % 100 == 0:
                losses, acc = sess.run([train_loss, train_acc])
                print('Step: %6d, loss: %.8f, accuracy: %.2f%%' % (step, losses, acc))
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)

            if step % 1000 == 0 or step == MAX_STEP - 1:  # 保存检查点
                checkpoint_path = os.path.join(logs_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done.')
    finally:
        coord.request_stop()

    coord.join(threads=threads)
    sess.close()

定义验证函数

# 测试检查点
def eval():
    N_CLASSES = 2
    IMG_W = 299
    IMG_H = 299
    BATCH_SIZE = 1
    MAX_STEP = 512
    min_after_dequeue=0

    test_dir = '/home/xieqi/project/cat_dog/val.tfrecords' #测试集数据
    logs_dir = '/home/xieqi/project/cat_dog/log_v3'     # 检查点目录
    false_pic_dir = '/home/xieqi/project/cat_dog/false_pic/' #错误分类的图片存储地址

    # 输入要生成的字符串的一维字符张量,输出字符串队列,shuffle默认为True
    filename_queue = tf.train.string_input_producer([test_dir], num_epochs=1)
    train_image, train_label = read_and_decode(filename_queue, image_W=IMG_W, image_H=IMG_H,
                batch_size=BATCH_SIZE,min_after_dequeue=min_after_dequeue) # 返回的为tensor

    train_labels = tf.one_hot(train_label, N_CLASSES)

    train_logits, _ = inception_v3(train_image, N_CLASSES)
    train_logits = tf.nn.softmax(train_logits)  # 用softmax转化为百分比数值

    #计算准确率
    correct_num = tf.placeholder('float')
    correct_pre = tf.div(correct_num, MAX_STEP)

    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    sess = tf.Session()
    sess.run(init_op)
    # 载入检查点
    saver = tf.train.Saver()
    print('\n载入检查点...')

    # 通过checkpoint文件找到模型文件名,有两个属性:model_checkpoint_path最新的模型文件的文件名
    # all_model_checkpoint_paths未被删除的所有模型文件的文件名
    ckpt = tf.train.get_checkpoint_state(logs_dir)

    if ckpt and ckpt.model_checkpoint_path:
        global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('载入成功,global_step = %d\n' % global_step)
    else:
        print('没有找到检查点')
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        correct = 0
        wrong = 0
        dt_list = []
        for step in range(MAX_STEP):

            if coord.should_stop():
                break

            st = time.time()
            image, prediction, labels = sess.run([train_image, train_logits, train_labels])
            dt = time.time() - st
            dt_list.append(dt)

            p_max_index = np.argmax(prediction)
            c_max_index = np.argmax(labels)

            if p_max_index == c_max_index:
                for i in range(BATCH_SIZE):
                    correct += 1
            else:
                for i in range(BATCH_SIZE):
                    wrong += 1
                    cv2.imwrite(false_pic_dir+'ture'+str(labels)+'predict'+ \
                                str(prediction)+'.jpg', image[i])

        accuray_rate = sess.run(correct_pre,feed_dict={correct_num: correct})
        velocity = np.mean(dt_list)
        print('Total: %5d, correct: %5d, wrong: %5d, accuracy: %3.2f%%, each speed: %.4fs' %
            (MAX_STEP, correct, wrong, accuray_rate * 100, velocity))
    except tf.errors.OutOfRangeError:
        print('OutOfRange')
    finally:
        coord.request_stop()

    coord.join(threads=threads)
    sess.close()

主函数

if __name__ == '__main__':
    training()
    #eval()

你可能感兴趣的:(tfrecord数据集训练及验证)