R-CNN原理详解与代码超详细讲解(四)--train_predict代码讲解

R-CNN原理详解与代码超详细讲解(四)–train_predict代码讲解

config代码

IMAGE_WIDTH = 227
IMAGE_HEIGHT = 227
IMAGE_CHANNEL = 3
CLASS_NUMBER = 3

ALEX_NET_MAT_FILE_PATH = "C:/Users/user/Desktop/05_rcnn/AlexNet预加载模型/imagenet-caffe-alex.mat"

ORIGINAL_FINE_TUNE_DATA_FILE_PATH = r'C:\Users\user\Desktop\05_rcnn\img_datas\fine_tune_list.txt' #微调训练数据路径
TRAIN_DATA_FILE_PATH = './datas/traning_data.npy' #提取出的datas存放路径
TRAIN_LABEL_DICT_FILE_PATH = './datas/label_dict.pkl' #svm训练标签文件路径,格式为:{'2': 1, '1': 2}

FINE_TUNE_SUMMARY_WRITER_LOG_DIR = './output/graph/fine_tune'
FINE_TUNE_CHECKPOINT_DIR = './output/models/fine_tune'
FINE_TUNE_CHECKPOINT_FILENAME = 'models.ckpt'
FINE_TUNE_MAX_STEP = 10000
FINE_TUNE_SUMMARY_STEP = 10
FINE_TUNE_CHECKPOINT_STEP = 50
FINE_TUNE_INITIAL_LEARNING_RATE = 0.001
FINE_TUNE_DECAY_STEPS = 1000
FINE_TUNE_DECAY_RATE = 0.99
FINE_TUNE_IOU_THRESHOLD = 0.5  #微调训练的时候正负样本区别的iou的大小
FINE_TUNE_POSITIVE_BATCH_SIZE = 8 #微调训练的时候正样本的批次大小
FINE_TUNE_NEGATIVE_BATCH_SIZE = 24 #微调训练的时候负样本的批次大小

TRAIN_SVM_HIGHER_FEATURES_DATA_FILE_PATH = "./datas/svm/higher_features_{}.npy" #svm训练所需要的cnn网络所提取的高阶特征,'shape': (10, 4097)
SVM_CHECKPOINT_FILE_PATH = "./output/models/svm/model_{}.pkl" #具体哪个类别训练的svm模型存放路径

train_predict代码

class SolverType(object):
    TRAIN_FINE_TUNE_MODEL = 0  # fine tune 的AlexNet微调训练
    GENERATE_TRAIN_SVM_FEATURES = 1  # 获取SVM训练用的高阶特征
    TRAIN_SVM_MODEL = 2  # 训练SVM模型
    GENERATE_TRAIN_REGRESSION_FEATURES = 3  # 获取训练回归模型的高阶特征
    TRAIN_REGRESSION_MODEL = 4  # 训练回归模型
    PREDICT_BOUNDING_BOX = 5  # 表示整个预测过程
    PREDICT_BOUNDING_BOX_STEP1 = 6  # 表示预测过程的第一步:获取SS边框
    PREDICT_BOUNDING_BOX_STEP2 = 7  # 表示预测过程的第二步:获取边框的高阶特征值
    PREDICT_BOUNDING_BOX_STEP3 = 8  # 表示预测过程的第三步:获取SVM预测结果
    PREDICT_BOUNDING_BOX_STEP4 = 9  # 表示预测过程的第四步:获取回归预测结果


class Solver(object):
    def __init__(self, solver_type):
        # 输入的属性信息
        self.is_training = False
        self.is_svm = False

        if SolverType.TRAIN_FINE_TUNE_MODEL == solver_type:
            with tf.Graph().as_default():
                print("进行Fine Tune模型训练操作....")
                self.is_training = True
                self.net = AlexNet(alexnet_mat_file_path=cfg.ALEX_NET_MAT_FILE_PATH,
                                   is_training=self.is_training)
                self.data_loader = FlowerDataLoader()

                # 设置相关配置文件
                self.__set_fine_tune_config()

                # 输出文件路径的检查
                check_directory(self.summary_writer_log_dir)
                check_directory(self.checkpoint_dir)

                # 构造全局step对象
                self.__get_or_create_global_step()

                # 构造训练对象
                self.__create_tf_train_op()

                # 构造持久化对象
                self.__create_tf_saver()

                # 构造可视化对象
                self.__create_tf_summary()

                # 构造会话以及初始化变量
                self.__create_tf_session_and_initial()

                # 设置运行的方法
                self.run = self.__fine_tune_train
        elif SolverType.GENERATE_TRAIN_SVM_FEATURES == solver_type:
            with tf.Graph().as_default():
                print("生成SVM训练用高阶特征属性,并持久化磁盘文件....")
                self.is_training = False
                self.is_svm = True
                self.net = AlexNet(alexnet_mat_file_path=cfg.ALEX_NET_MAT_FILE_PATH,
                                   is_training=self.is_training, is_svm=self.is_svm)
                self.data_loader = FlowerDataLoader()

                # 设置相关配置文件
                self.__set_fine_tune_config()

                # 输出文件路径的检查
                check_directory(self.summary_writer_log_dir)
                check_directory(self.checkpoint_dir, created=False, error=True)

                # 构造全局step对象
                self.__get_or_create_global_step()

                # 构造持久化对象
                self.__create_tf_saver()

                # 构造可视化对象
                self.__create_tf_summary()

                # 构造会话以及初始化变量
                self.__create_tf_session_and_initial()
                # 设置运行的方法
                self.run = self.__persistent_svm_higher_features
        elif SolverType.TRAIN_SVM_MODEL == solver_type:
            print("进行SVM模型训练操作....")
            self.is_svm = True
            self.is_training = True
            self.net = SVMModel(is_training=self.is_training)
            # 设置运行的方法
            self.run = self.__svm_train

    def __set_fine_tune_config(self):
        self.initial_learning_rate = cfg.FINE_TUNE_INITIAL_LEARNING_RATE #初始学习率
        self.decay_steps = cfg.FINE_TUNE_DECAY_STEPS#衰减步数
        self.decay_rate = cfg.FINE_TUNE_DECAY_RATE#衰减系数
        self.summary_writer_log_dir = cfg.FINE_TUNE_SUMMARY_WRITER_LOG_DIR#微调summary图存放地址
        self.checkpoint_dir = cfg.FINE_TUNE_CHECKPOINT_DIR#模型地址
        self.checkpoint_path = os.path.join(self.checkpoint_dir, cfg.FINE_TUNE_CHECKPOINT_FILENAME)
        self.max_steps = cfg.FINE_TUNE_MAX_STEP#微调训练最大步数
        self.summary_step = cfg.FINE_TUNE_SUMMARY_STEP#训练多少步骤保存一次summary
        self.checkpoint_step = cfg.FINE_TUNE_CHECKPOINT_STEP

    def __get_or_create_global_step(self):
        # 获取全局步长变量
        self.global_step = tf.train.get_or_create_global_step()

    def __create_tf_train_op(self):
        # 构造优化器
        if self.is_training:
            with tf.variable_scope("train"):
                # 模型更新的学习率
                self.learning_rate = tf.train.exponential_decay(
                    learning_rate=self.initial_learning_rate,
                    global_step=self.global_step,
                    decay_steps=self.decay_steps,
                    decay_rate=self.decay_rate,
                    name='learning_rate')
                tf.summary.scalar('learning_rate', self.learning_rate)

                self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) \
                    .minimize(self.net.total_loss, global_step=self.global_step)
                # 加一个模型参数滑动平均的操作,也就是将最近几个更新批次的参数更改为均值的方式
                self.ema = tf.train.ExponentialMovingAverage(0.99)
                with tf.control_dependencies([self.optimizer]):
                    self.train_op = self.ema.apply(tf.trainable_variables())

    def __create_tf_saver(self):
        # 模型持久化
        self.saver = tf.train.Saver()

    def __create_tf_summary(self):
        # 可视化对象
        self.summary = tf.summary.merge_all()
        self.writer = tf.summary.FileWriter(self.summary_writer_log_dir, graph=tf.get_default_graph())

    def __create_tf_session_and_initial(self):
        # 构造会话
        self.session = tf.Session()

        # 参数初始化
        self.session.run(tf.global_variables_initializer())

        # 如果模型存在的话,做一个模型恢复的操作
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("进行模型恢复操作...")
            # 恢复模型
            self.saver.restore(self.session, ckpt.model_checkpoint_path)
            # 恢复checkpoint的管理信息
            self.saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths)

    def __svm_train(self):
        """
        进行SVM模型训练
        :return:
        """
        self.net.train()

    def __fine_tune_train(self):
        if not self.is_training:
            raise Exception("Train method request set 'is_training' parameter is True.")
        # 获取开始的步骤
        start_step = self.session.run(self.global_step)
        # 获取最大迭代步骤
        end_step = start_step + self.max_steps
        # 遍历进行训练
        for step in range(start_step, end_step):
            # 1. 获取数据
            images, labels = self.data_loader.get_fine_tune_batch()

            # 2. 模型训练
            feed_dict = {
     self.net.input_data: images, self.net.label: labels}

            # 3. 间隔性的进行可视化操作
            if step % self.summary_step == 0:
                summary_, loss_, accuracy_, _ = self.session.run(
                    [self.summary, self.net.total_loss, self.net.accuracy, self.train_op],
                    feed_dict=feed_dict)
                self.writer.add_summary(summary_, global_step=step)
                print("Training Step:{}, Loss:{}, Accuracy:{}".format(step, loss_, accuracy_))
            else:
                self.session.run(self.train_op, feed_dict=feed_dict)

            # 4. 间断性的做一个模型持久化的操作
            if step % self.checkpoint_step == 0:
                print("Saving model to {}".format(self.checkpoint_dir))
                self.saver.save(sess=self.session, save_path=self.checkpoint_path, global_step=step)

    def __fine_tune_predict(self, images):
        """
        运行,得到Fine Tune模型的返回值
        :param images:
        :return:
        """
        return self.session.run(self.net.logits, feed_dict={
     self.net.input_data: images})

    def __persistent_svm_higher_features(self):
        """
        持久化用于svm模型训练的高阶特征数据
        在svm模型训练中,是针对每个类别训练一个svm模型,所有在这里需要对于每个类别产生一个训练数据文件
        :return:
        """
        # 1. 获取标签值
        check_directory(cfg.TRAIN_LABEL_DICT_FILE_PATH, created=False, error=True) #格式为:{'2': 1, '1': 2}
        class_name_2_index_dict = pickle.load(open(cfg.TRAIN_LABEL_DICT_FILE_PATH, 'rb'))

        # 2. 遍历所有标签值

        for class_name, index in class_name_2_index_dict.items():
            print("Start process type '{}/{}' datas...".format(index, class_name))
            X = None
            Y = None

            # a. 获取当前标签对应的所有训练数据
            # TODO: 一般一个批次一个批次的进行数据的获取预测
            images, labels = self.data_loader.get_structure_higher_features(label=index)

            # b. 过滤数据
            if images is None or labels is None:
                print("没办法获取标签:{}对应的数据集!!!".format(index))
                continue

            # c. 调用预测代码,获取预测结果

            print(np.shape(images), np.shape(labels))
            higher_features = self.__fine_tune_predict(images)

            # d. 赋值
            X = higher_features
            Y = labels
            print("Final Feature Attribute Structure:{} - {}".format(np.shape(X), np.shape(Y)))
            print("Number of occurrences of each category:{}".format(collections.Counter(Y)))

            # e. 数据持久化保存
            # 合并数据
            data = np.concatenate((X, np.reshape(Y, (-1, 1))), axis=1)
            # 文件路径获取
            svm_higher_features_save_path = cfg.TRAIN_SVM_HIGHER_FEATURES_DATA_FILE_PATH.format(index)
            check_directory(os.path.dirname(svm_higher_features_save_path))
            # 数据输出
            np.save(svm_higher_features_save_path, data)


def run_solver():
    flag = 2
    if flag == 0:
        solver = Solver(SolverType.TRAIN_FINE_TUNE_MODEL)
    elif flag == 1:
        solver = Solver(SolverType.GENERATE_TRAIN_SVM_FEATURES)
    elif flag == 2:
        solver = Solver(SolverType.TRAIN_SVM_MODEL)
    solver.run()


if __name__ == '__main__':
    run_solver()

你可能感兴趣的:(目标检测,R-CNN,python,tensorflow,深度学习,可视化)