前几篇博客写了如何处理数据,如何把用自己的数据训练VGG-16,如何把训练好的模型保存。而在实际应用中,并不是所有的操作都是为了分类的,有时候需要提取图像的特征,那么怎么利用已经保存的模型提取特征呢?
“桃叶儿尖上尖,柳叶儿就遮满了天”
测试数据转换成tfrecords,教程:点击打开链接
保存训练好的VGG-16模型,教程:点击打开链接
首先把测试数据转换成tfrecords,然后读取出来,代码和前面博客写的一致:
#读取文件
def read_and_decode(filename,batch_size):
#根据文件名生成一个队列
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [300, 300, 3]) #图像归一化大小
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #图像减去均值处理
label = tf.cast(features['label'], tf.int32)
#特殊处理
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size= batch_size,
num_threads=64,
capacity=2000,
min_after_dequeue=1500)
return img_batch, tf.reshape(label_batch,[batch_size])
最核心的部分是使用saver类中的restore方法,核心代码如下:
saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta") #注意路径
saver.restore(sess, "./model/checkpoint/model.ckpt") #保存模型的路径
利用的是graph.get_tensor_by_name(“名字”),则首先获取模型中占位符,然后将测试数据传进去,这是最核心的地方,想要提取特征也是通过名字获取张量,比如要提取fc7的特征,则fc7_features=graph.get_tensor_by_name("fc7:0") 。核心代码如下:
graph = tf.get_default_graph() #获取恢复模型的图模型
x_holder = graph.get_tensor_by_name("x_holder:0") # 获取占位符
fc7_features=graph.get_tensor_by_name("fc7:0") #获取要提取的特征,用该层的名字
keep_prob=graph.get_tensor_by_name("keep_prob:0") #同上
# 通过张量的名称来获取张量
print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout})) #给占位符重新赋值,则可以提取输入图像的特征
整个过程,博主用了好几天的时间才调通,中间的心酸历程就不多说了,直接放完整的提取特征代码吧,如果想用保存的模型做分类,而不是提特征,则举一反三,我觉得并不难,修改一下即可:
完整代码:
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 2 17:12:00 2018
@author: Heroin 高永标,upc
"""
import tensorflow as tf
#读取文件
def read_and_decode(filename,batch_size):
#根据文件名生成一个队列
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [300, 300, 3]) #图像归一化大小
# img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #图像减去均值处理
label = tf.cast(features['label'], tf.int32)
#特殊处理
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size= batch_size,
num_threads=64,
capacity=2000,
min_after_dequeue=1500)
return img_batch, tf.reshape(label_batch,[batch_size])
batch_size=4
dropout=1.0
tfrecords_file = 'train.tfrecords' #保存的测试数据
BATCH_SIZE = 4
image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)
#print(image_batch)
#sess=tf.InteractiveSession()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess,coord = coord)
image,label=sess.run([image_batch,label_batch])
saver = tf.train.import_meta_graph("model/checkpoint/model.ckpt.meta") #保存的模型路径
saver.restore(sess, "./model/checkpoint/model.ckpt")
graph = tf.get_default_graph()
x_holder = graph.get_tensor_by_name("x_holder:0") # 获取占位符
fc7_features=graph.get_tensor_by_name("fc7:0") #获取要提取的特征,用名字
keep_prob=graph.get_tensor_by_name("keep_prob:0")
# 通过张量的名称来获取张量
print(sess.run(fc7_features,feed_dict={x_holder:image,keep_prob:dropout})) #给占位符重新赋值
sess.close()