Classification Task 6: Combining the Previous Parts into Training and Validation Modules

Following the previous post, here are the two versions of the code.

Without batch_norm:

from using_dataset_to_read_the_tfrecord4 import *
from AlexNet_5 import *
from set_config0 import *
import tensorflow as tf  # explicit import (tf may also come in via the wildcard imports)
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm


def main():

    train_file_name = 'E:/111project/tfrecord/train.tfrecords'
    validation_file_name = 'E:/111project/tfrecord/validation.tfrecords'

    train_data = create_dataset(train_file_name, batch_size=batch_size,
                                resize_height=resize_height, resize_width=resize_width, num_class=num_class)
    validation_data = create_dataset(validation_file_name, batch_size=batch_size,
                                     resize_height=resize_height, resize_width=resize_width, num_class=num_class)

    train_iterator = train_data.make_initializable_iterator()
    # val_iterator = validation_data.make_initializable_iterator()
    # val_iterator = validation_data.make_one_shot_iterator()
    val_iterator = tf.data.Iterator.from_structure(validation_data.output_types,
                                                   validation_data.output_shapes)
    val_op = val_iterator.make_initializer(validation_data)
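    # Note: the commented lines above are alternative ways to build the
    # validation iterator. tf.data.Iterator.from_structure creates a
    # reinitializable iterator, so running val_op rewinds the validation set.
    # The same pattern could share one iterator between both datasets,
    # e.g. (a sketch only, not what this script does):
    #   iterator = tf.data.Iterator.from_structure(train_data.output_types,
    #                                              train_data.output_shapes)
    #   train_init_op = iterator.make_initializer(train_data)
    #   val_init_op = iterator.make_initializer(validation_data)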

    train_images, train_labels = train_iterator.get_next()
    val_images, val_labels = val_iterator.get_next()

    x = tf.placeholder(tf.float32, shape=[None, resize_height, resize_width, 3], name='x')
    y = tf.placeholder(tf.int32, shape=[None, num_class], name='y')
    keep_prob = tf.placeholder(tf.float32)

    fc3, parameters = inference(x, num_class, keep_prob)

    # learning rate: decays every decay_step steps; global_ is incremented
    # by the optimizer through the global_step argument of minimize()
    with tf.name_scope('learning_rate'):
        global_ = tf.Variable(0, trainable=False)
        lr = tf.train.exponential_decay(learning_rate, global_,
                                        decay_step, decay_rate, staircase=True)

    # loss
    with tf.name_scope('loss'):
        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=fc3, labels=y))

    # optimizer
    # NOTE: Adam is constructed with the fixed learning_rate, so the decayed lr
    # above is only printed and logged, never applied to the updates; see the
    # sketch after the code for how to wire the decay in.
    with tf.name_scope('optimizer'):
        # optimizer = tf.train.GradientDescentOptimizer(lr)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss_op, global_step=global_)

    # accuracy
    with tf.name_scope("accuracy"):
        correct_pred = tf.equal(tf.argmax(fc3, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Tensorboard
    train_tensorboard = 'E:/log/train/'
    val_tensorboard = 'E:/log/val/'
    tf.summary.scalar('loss', loss_op)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.scalar('learning_rate', lr)
    merged_summary = tf.summary.merge_all()
    # train_writer = tf.summary.FileWriter(train_tensorboard, sess.graph)
    # val_writer = tf.summary.FileWriter(val_tensorboard)

    # saver
    saver = tf.train.Saver()

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)

        train_writer = tf.summary.FileWriter(train_tensorboard, sess.graph)
        val_writer = tf.summary.FileWriter(val_tensorboard)

        # left over from the old queue-based input pipeline; tf.data does not
        # use queue runners, so these two lines are effectively no-ops
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        trainAcc = []
        valACC = []
        trainLoss = []
        valLoss = []
        lrate1 = []
        totalIteration = []
        for i in range(iteration):
            print('iteration: {}'.format(i + 1))
            totalIteration.append(i + 1)
            # re-initialize both datasets at the start of every iteration;
            # each iteration then draws a single batch from each pipeline
            sess.run(train_iterator.initializer)
            sess.run(val_op)

            try:
                train_batch_images, train_batch_labels = sess.run([train_images, train_labels])

                train_loss, lr1, _, train_acc = sess.run([loss_op, lr, train_op, accuracy],
                                                         feed_dict={x: train_batch_images,
                                                                    y: train_batch_labels,
                                                                    keep_prob: drop_rate,
                                                                    #global_: i
                                                                    })
                print('lr is: {}\n'.format(lr1))
                print("train loss: %.8f, train acc: %.8f" % (train_loss, train_acc))

                s = sess.run(merged_summary, feed_dict={x: train_batch_images,
                                                        y: train_batch_labels,
                                                        keep_prob: 1.0,
                                                        global_: i})
                train_writer.add_summary(summary=s, global_step=i)

                trainAcc.append(train_acc)

                trainLoss.append(train_loss)

                lrate1.append(lr1)

                val_batch_images, val_batch_labels = sess.run([val_images, val_labels])
                val_loss, val_acc = sess.run([loss_op, accuracy], feed_dict={x: val_batch_images,
                                                                             y: val_batch_labels,
                                                                             keep_prob: 1.0})

                t = sess.run(merged_summary, feed_dict={x: val_batch_images,
                                                        y: val_batch_labels,
                                                        keep_prob: 1.0})
                val_writer.add_summary(summary=t, global_step=i)

                print("val loss: %.8f, val acc: %.8f" % (val_loss, val_acc))
                print('\n')
                valACC.append(val_acc)
                valLoss.append(val_loss)

            except tf.errors.OutOfRangeError:
                break

        coord.request_stop()
        coord.join(threads)
        plt.figure('accuracy')
        plt.plot(totalIteration, trainAcc, 'r', label='train_acc')
        plt.plot(totalIteration, valACC, 'b', label='val_acc')
        plt.ylim((0, 1))
        plt.legend()

        plt.figure('loss')
        plt.plot(totalIteration, trainLoss, 'r', label='train_loss')
        plt.plot(totalIteration, valLoss, 'b', label='val_loss')
        plt.legend()

        plt.figure('learning_rate')
        plt.plot(totalIteration, lrate1, 'g', label='lr')
        plt.legend()

        plt.show()


if __name__ == '__main__':
    main()
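
A note on the learning rate in this first version: the decayed lr is computed, printed, and written to TensorBoard, but the Adam optimizer is built with the constant learning_rate, so the decay never affects the actual weight updates. A minimal sketch of wiring the decay into the optimizer, reusing the names from set_config0 (a sketch only, not the author's final choice):

# Sketch: let the exponential decay actually drive the optimizer.
# minimize() increments global_step automatically, so nothing needs to be fed.
global_step = tf.Variable(0, trainable=False, name='global_step')
lr = tf.train.exponential_decay(learning_rate, global_step,
                                decay_step, decay_rate, staircase=True)
train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss_op,
                                                             global_step=global_step)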

With batch_norm added:

from using_dataset_to_read_the_tfrecord_4 import *
from AlexNet_with_batchnorm_5 import *
from set_config_0 import *
import tensorflow as tf  # explicit import (tf may also come in via the wildcard imports)
from tqdm import tqdm


def main():
    train_file_name = 'E:/111project/tfrecord/train.tfrecords'
    validation_file_name = 'E:/111project/tfrecord/validation.tfrecords'

    train_data = create_dataset(train_file_name, batch_size=batch_size,
                                resize_height=resize_height, resize_width=resize_width, num_class=num_class)
    validation_data = create_dataset(validation_file_name, batch_size=batch_size,
                                     resize_height=resize_height, resize_width=resize_width, num_class=num_class)

    train_data = train_data.repeat()
    validation_data = validation_data.repeat()

    train_iterator = train_data.make_one_shot_iterator()
    val_iterator = validation_data.make_one_shot_iterator()

    train_images, train_labels = train_iterator.get_next()
    val_images, val_labels = val_iterator.get_next()

    x = tf.placeholder(tf.float32, shape=[None, resize_height, resize_width, 3], name='x')
    y = tf.placeholder(tf.int32, shape=[None, num_class], name='y')
    keep_prob = tf.placeholder(tf.float32)
    is_training = tf.placeholder(tf.bool)

    fc3 = inference(x, num_class, keep_prob, is_training)

    with tf.name_scope('learning_rate'):
        # global_ = tf.Variable(tf.constant(0))
        global_ = tf.placeholder(tf.int32)
        lr = tf.train.exponential_decay(learning_rate, global_,
                                        decay_step, decay_rate, staircase=True)
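    # Unlike the first script, global_ here is a placeholder fed with the
    # iteration index on every sess.run, rather than a Variable incremented by
    # minimize(); both are valid ways to drive exponential_decay.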

    with tf.name_scope('loss'):
        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=fc3, labels=y))

    with tf.name_scope('optimizer'):
        # run the batch-norm moving-average updates (collected in UPDATE_OPS)
        # together with every training step
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)

    with tf.name_scope("accuracy"):
        correct_pred = tf.equal(tf.argmax(fc3, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    train_tensorboard = 'E:/log/train/'
    val_tensorboard = 'E:/log/val/'
    tf.summary.scalar('loss', loss_op)
    tf.summary.scalar('accuracy', accuracy)
    # tf.summary.scalar('learning_rate', lr)
    merged_summary = tf.summary.merge_all()

    saver = tf.train.Saver(max_to_keep=3)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init)
        train_writer = tf.summary.FileWriter(train_tensorboard, sess.graph)
        val_writer = tf.summary.FileWriter(val_tensorboard)

        # as in the first script, queue runners are not needed for tf.data
        # pipelines; these lines are effectively no-ops
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        max_acc = 0

        for i in tqdm(range(iteration)):
            print('\niteration: {}'.format(i + 1))
            try:
                train_batch_images, train_batch_labels = sess.run([train_images, train_labels])
                train_loss, lr1, _, train_acc = sess.run([loss_op, lr, train_op, accuracy],
                                                         feed_dict={x: train_batch_images,
                                                                    y: train_batch_labels,
                                                                    keep_prob: drop_rate,
                                                                    is_training: True,
                                                                    global_: i
                                                                    })

                val_batch_images, val_batch_labels = sess.run([val_images, val_labels])
                val_loss, val_acc = sess.run([loss_op, accuracy], feed_dict={x: val_batch_images,
                                                                             y: val_batch_labels,
                                                                             keep_prob: 1.0,
                                                                             is_training: False,
                                                                             global_: i
                                                                             })
                print('lr is: {}'.format(lr1))
                print("train loss: %.6f, train acc: %.6f" % (train_loss, train_acc))
                s = sess.run(merged_summary, feed_dict={x: train_batch_images,
                                                        y: train_batch_labels,
                                                        keep_prob: drop_rate,
                                                        is_training: True,
                                                        global_: i
                                                        })
                train_writer.add_summary(summary=s, global_step=i)

                print("val loss: %.6f, val acc: %.6f" % (val_loss, val_acc))
                print('\n')
                t = sess.run(merged_summary, feed_dict={x: val_batch_images,
                                                        y: val_batch_labels,
                                                        keep_prob: 1.0,
                                                        is_training: False,
                                                        global_: i
                                                        })
                val_writer.add_summary(summary=t, global_step=i)

                # save a checkpoint only when validation accuracy improves
                if val_acc > max_acc:
                    max_acc = val_acc
                    saver.save(sess, 'model/alexnet.ckpt', global_step=i)

            except tf.errors.OutOfRangeError:
                break

        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    main()
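
The checkpoints saved above (the saver keeps the three most recent saves, and a save happens only when validation accuracy improves) can be restored later for inference. A minimal sketch, assuming the same inference function and config values are imported; the commented prediction line is only illustrative:

# Sketch (hypothetical inference snippet): rebuild the graph, restore the
# newest saved checkpoint, and run the network in evaluation mode.
x = tf.placeholder(tf.float32, shape=[None, resize_height, resize_width, 3], name='x')
keep_prob = tf.placeholder(tf.float32)
is_training = tf.placeholder(tf.bool)
fc3 = inference(x, num_class, keep_prob, is_training)

saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint('model/')
    saver.restore(sess, ckpt)
    # predictions = sess.run(tf.argmax(fc3, 1),
    #                        feed_dict={x: some_images, keep_prob: 1.0, is_training: False})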
