Tensorflow-花分类-图像再训练-part-4-整理翻译

继续前面的三篇文章Part-1,Part-2,Part-3,这一篇我们来完善存储和恢复机制。


把计算图保存到文件save_graph_to_file

下面是增加的代码,先不要运行,稍后一起测试:

#将图保存到文件,必要时创建允许的量子化    
def save_graph_to_file(graph, graph_file_name, module_spec, class_count):
    sess, _, _, _, _, _ = build_eval_session(module_spec, class_count)
    graph = sess.graph

    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ['final_tensor_name'])

    with tf.gfile.FastGFile(graph_file_name, 'wb') as f:
        f.write(output_graph_def.SerializeToString())  

保存评估模型export_model

注意每次使用前必须把旧的saved_model文件夹删除或改名。

#导出评估eval图的模型pd文件用于提供服务
saved_model_dir=os.path.join(dir_path,'saved_model'+str(datetime.now()))
def export_model(module_spec, class_count):
    sess, in_image, _, _, _, _ = build_eval_session(module_spec, class_count)
    graph = sess.graph
    with graph.as_default():
        #输入输出点
        inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
        out_classes = sess.graph.get_tensor_by_name('final_tensor_name:0')
        outputs = {
            'prediction': tf.saved_model.utils.build_tensor_info(out_classes)
        }
        #创建签名
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=inputs,
            outputs=outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        #初始化
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

        #保存saved_model
        builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:signature
            },
            legacy_init_op=legacy_init_op)
    builder.save()


改进最终再训练函数

最后我们把run_final_retrain函数改进一下,增加评估eval和保存、导出功能。

注意每次使用前必须把旧的saved_model文件夹删除或改名。

以下是修改后的代码(这里参数train_steps=10,所以得到的模型精度也非常糟糕。如果您的计算机允许,官方默认是4000,请量力而为):

#保存概要和checkpoint路径设置
CHECKPOINT_NAME = os.path.join(dir_path,'checkpoints/retrain')
summaries_dir=os.path.join(dir_path,'summaries/train')
ensure_dir_exists(os.path.join(dir_path,'output')) 
saved_model_path=os.path.join(dir_path,'output/out_graph.pd')
output_label_path=os.path.join(dir_path,'output/labels.txt')

#执行训练兵保存checkpoint的函数
def run_final_retrain(train_steps=10,
             eval_step_interval=5,
             do_distort=True):
    module_spec = hub.load_module_spec(HUB_MODULE)
    
    #创建图并获取相关的张量入口
    graph, bottleneck_tensor, resized_image_tensor, wants_quantization = (
        create_module_graph(module_spec))
    
    with graph.as_default(): 
        #添加训练相关的张量和操作节点入口
        (train_step, cross_entropy, bottleneck_input,ground_truth_input,
         final_tensor) = add_final_retrain_ops(5, 'final_tensor_name', 
                                               bottleneck_tensor,wants_quantization,True)    

    with tf.Session(graph=graph) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        
        #添加图片解码相关的张量入口操作
        jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding(module_spec)
        
        #读取图片的bottleneck数据
        if do_distort:
            distorted_jpeg_data_tensor,distorted_image_tensor=add_input_distortions(module_spec,True,50,50,50)
        else:
            cache_bottlenecks(sess, 
                              jpeg_data_tensor,decoded_image_tensor, 
                              resized_image_tensor,bottleneck_tensor)           
            
        #创建评估新层精度的操作
        evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input)
        
        #记录概要信息与保存
        train_saver = tf.train.Saver()
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(summaries_dir+'/retrain',sess.graph)      
        validation_writer = tf.summary.FileWriter(summaries_dir + '/validation')        
        
        #开始运作!
        for i in range(train_steps):
            #获取图片bottleneck数据
            if do_distort:
                (train_bottlenecks,train_ground_truth) = get_random_distorted_bottlenecks(
                    sess,BATCH_SIZE, 'training',
                    distorted_jpeg_data_tensor,distorted_image_tensor, 
                    resized_image_tensor, bottleneck_tensor)
            else:
                (train_bottlenecks,train_ground_truth, _) = get_random_cached_bottlenecks(
                    sess, BATCH_SIZE, 'training',
                    jpeg_data_tensor,decoded_image_tensor, 
                    resized_image_tensor, bottleneck_tensor)
            
            #启动训练
            train_summary, _ = sess.run(
                [merged, train_step],
                feed_dict={bottleneck_input: train_bottlenecks,
                           ground_truth_input: train_ground_truth})
            train_writer.add_summary(train_summary, i)
            
            #间隔性启动评估
            is_last_step = (i + 1 == train_steps)
            if (i % eval_step_interval) == 0 or is_last_step:
                train_accuracy, cross_entropy_value = sess.run(
                    [evaluation_step, cross_entropy],
                    feed_dict={bottleneck_input: train_bottlenecks,
                               ground_truth_input: train_ground_truth})
                
                tf.logging.info('%s: Step %d: Train accuracy = %.1f%%' %(datetime.now(), i, train_accuracy * 100))
                tf.logging.info('%s: Step %d: Cross entropy = %f' %(datetime.now(), i, cross_entropy_value))
            
                #使用不同的bottleneck数据进行评估
                validation_bottlenecks, validation_ground_truth, _ = (
                    get_random_cached_bottlenecks(
                        sess, 10, 'validation', 
                        jpeg_data_tensor,decoded_image_tensor, 
                        resized_image_tensor, bottleneck_tensor))
                #启动评估!
                validation_summary, validation_accuracy = sess.run(
                    [merged, evaluation_step],
                    feed_dict={bottleneck_input: validation_bottlenecks,
                               ground_truth_input: validation_ground_truth})
                
                validation_writer.add_summary(validation_summary, i)
                tf.logging.info('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' %(datetime.now(), i, validation_accuracy * 100,len(validation_bottlenecks)))            
            
        
            #间隔保存中介媒体文件,为训练保存checkpoint
            if (i % eval_step_interval == 0 and i > 0):
                train_saver.save(sess, CHECKPOINT_NAME)
                intermediate_file_name = (os.path.join(dir_path + 'intermediate') + str(i) + '.pb')
                tf.logging.info('Save intermediate result to : '+intermediate_file_name)
                save_graph_to_file(graph, intermediate_file_name, module_spec,5)

        #保存模型    
        train_saver.save(sess, CHECKPOINT_NAME)        
        
        #执行最终评估
        run_final_eval(sess, module_spec, 5,
                       jpeg_data_tensor, decoded_image_tensor,
                       resized_image_tensor,bottleneck_tensor)

        
        tf.logging.info('Save final result to : ' + saved_model_path)
        if wants_quantization:
            tf.logging.info('The model is instrumented for quantization with TF-Lite')
        save_graph_to_file(graph, saved_model_path, module_spec, 5)
        with tf.gfile.FastGFile(output_label_path, 'w') as f:
            f.write('\n'.join(image_lists.keys()) + '\n')
        export_model(module_spec, 5)

案例小结

这个案例来自Tensorflow官方教程,之前两个相对都比较简单,代码量只有100行左右,这个案例官方原代码突然有1300行之多,大有才学了十以内加减法然后就讲微积分方程的感觉。

这里整个案例去掉了很多官方代码中我认为无关紧要的部分,仍然有600多行,如果有时间我还会在整理这个案例,希望能只保留关键流程代码,两三百行不能再多了。

已经读到这里的用户实属难得,如果遇到困难,请从百度网盘下载(密码:lzjg)直接下载final.py文件使用。请注意文件读写权限,每次运行前请删除saved_model文件夹。


探索人工智能的新边界

如果您发现文章错误,请不吝留言指正;
如果您觉得有用,请点喜欢;
如果您觉得很有用,感谢转发~


END

你可能感兴趣的:(Tensorflow-花分类-图像再训练-part-4-整理翻译)