深度学习tensorflow实战笔记(4)利用保存的VGG-16CNN网络模型提取特征

    前几篇博客写了如何处理数据,如何把用自己的数据训练VGG-16,如何把训练好的模型保存。而在实际应用中,并不是所有的操作都是为了分类的,有时候需要提取图像的特征,那么怎么利用已经保存的模型提取特征呢?

   “桃叶儿尖上尖,柳叶儿就遮满了天”

    测试数据转换成tfrecords,教程:点击打开链接

    保存训练好的VGG-16模型,教程:点击打开链接

1、读取测试数据

      首先把测试数据转换成tfrecords,然后读取出来,代码和前面博客写的一致:      

#读取文件
def read_and_decode(filename,batch_size):
    #根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [300, 300, 3])                #图像归一化大小
   # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5   #图像减去均值处理
    label = tf.cast(features['label'], tf.int32)        

    #特殊处理

    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size= batch_size,
                                                    num_threads=64,
                                                    capacity=2000,
                                                    min_after_dequeue=1500)
    return img_batch, tf.reshape(label_batch,[batch_size])

2、调取保存的训练好的VGG-16模型

    最核心的部分是使用saver类中的restore方法,核心代码如下:

    saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta")    #注意路径

    saver.restore(sess, "./model/checkpoint/model.ckpt")     #保存模型的路径

3、把测试数据传进去模型提取特征

     利用的是graph.get_tensor_by_name(“名字”),则首先获取模型中占位符,然后将测试数据传进去,这是最核心的地方,想要提取特征也是通过名字获取张量,比如要提取fc7的特征,则fc7_features=graph.get_tensor_by_name("fc7:0") 。核心代码如下:

    graph = tf.get_default_graph()  #获取恢复模型的图模型
    x_holder = graph.get_tensor_by_name("x_holder:0")  #    获取占位符
    fc7_features=graph.get_tensor_by_name("fc7:0")     #获取要提取的特征,用该层的名字
    keep_prob=graph.get_tensor_by_name("keep_prob:0")   #同上
    # 通过张量的名称来获取张量

    print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout}))  #给占位符重新赋值,则可以提取输入图像的特征
    

4、完整的代码

     整个过程,博主用了好几天的时间才调通,中间的心酸历程就不多说了,直接放完整的提取特征代码吧,如果想用保存的模型做分类,而不是提特征,则举一反三,我觉得并不难,修改一下即可:

完整代码:

# -*- coding: utf-8 -*-
"""
Created on Mon Apr  2 17:12:00 2018

@author: Heroin 高永标,upc
"""

import tensorflow as tf

#读取文件
def read_and_decode(filename,batch_size):
    #根据文件名生成一个队列
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回文件名和文件
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [300, 300, 3])                #图像归一化大小
   # img = tf.cast(img, tf.float32) * (1. / 255) - 0.5   #图像减去均值处理
    label = tf.cast(features['label'], tf.int32)        

    #特殊处理

    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size= batch_size,
                                                    num_threads=64,
                                                    capacity=2000,
                                                    min_after_dequeue=1500)
    return img_batch, tf.reshape(label_batch,[batch_size])

batch_size=4
dropout=1.0

tfrecords_file = 'train.tfrecords'     #保存的测试数据
BATCH_SIZE = 4
image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)
#print(image_batch)


#sess=tf.InteractiveSession()
with tf.Session() as sess:
    coord = tf.train.Coordinator()  
    threads = tf.train.start_queue_runners(sess = sess,coord = coord)
    image,label=sess.run([image_batch,label_batch])  
    saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta")     #保存的模型路径
    saver.restore(sess, "./model/checkpoint/model.ckpt")
    graph = tf.get_default_graph()  
    x_holder = graph.get_tensor_by_name("x_holder:0")  #    获取占位符
    fc7_features=graph.get_tensor_by_name("fc7:0")     #获取要提取的特征,用名字
    keep_prob=graph.get_tensor_by_name("keep_prob:0")
    # 通过张量的名称来获取张量

    print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout}))  #给占位符重新赋值
    
    sess.close()    
    

你可能感兴趣的:(tensorflow-深度学习)