实验所需要的环境:
tensorflow-0.10
python-opencv
Image(PIL),ubuntu 14.04
本实验的数据集是点击打开链接微云
其中包括9种手势,部分形式见下面的图;如果链接失效,可以向我索要。
数据集中的图片是单通道的,但写代码时与多通道图片的处理差异不大。
文件的存放位置如下:
制作TFRecords的完整代码如下:
# coding: UTF-8
import os
import tensorflow as tf
from PIL import Image
import cv2
import numpy as np
# Side length in pixels: every image is resized to image_size x image_size
# before being stored, and reshaped to that size when read back.
image_size = 28 # tfrecords size
# TFRecords file that is written first and then read back for checking.
tf_save_path = "my_tfrecords/gesture_train.tfrecords"
# Text file read line by line; each line is used as one class name.
labels_path = "labels.txt"
# Root directory; each class name is used as a sub-directory of it.
image_path = "picture_b_image/"
# Path prefix for the JPEG files written while reading the records back.
tf_out_path = "tf_out/"
print("\n/**************************/")
print("\n gesture produce .tfrecords data~~~~")

# Read the class names, one per line, from the labels file.
# classes_read[i] is the name of class i (its index is later used as the
# integer label); class_id_cnt is the number of classes that were read.
classes_read = []
print("\n 读取样本的分类号:")
# The with-statement closes the file even if reading raises.
with open(labels_path) as f:
    for line in f:
        line = line.strip()
        if not line:
            # A blank line would otherwise become an empty class name and
            # make the converter list the image root directory itself.
            continue
        classes_read.append(line)
        # Single-argument print with %-formatting gives the same output
        # on Python 2 and Python 3.
        print("%d ) --> %s" % (len(classes_read), line))
class_id_cnt = len(classes_read)
print("\n")
# Generate the TFRecords file: one tf.train.Example per image, holding the
# class index ("label") and the raw pixel bytes ("img_raw").
writer = tf.python_io.TFRecordWriter(tf_save_path)
picture_cnt = 0  # total number of images written; reused when reading back
try:
    for index, name in enumerate(classes_read):
        class_path = image_path + name + '/'
        print("第  %d 类 开始转换~~~~" % index)
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            # Show a preview of every 60th image as a progress indicator.
            if picture_cnt % 60 == 0:
                img_cv = cv2.imread(img_path)
                # cv2.imread returns None when it cannot decode the file;
                # skip the preview instead of crashing inside imshow.
                if img_cv is not None:
                    cv2.namedWindow("image_tfrecords", 0)
                    cv2.imshow("image_tfrecords", img_cv)
                    cv2.waitKey(1)
            # Force 8-bit single-channel data so every record holds exactly
            # image_size * image_size bytes, which is what the reader's
            # reshape to [image_size, image_size, 1] requires.
            img = Image.open(img_path).convert("L")
            img = img.resize((image_size, image_size))
            img_raw = img.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_raw])),
            }))
            writer.write(example.SerializeToString())
            picture_cnt += 1
finally:
    # Always flush/close the record file and tear down the preview window,
    # even when one of the images fails to convert.
    writer.close()
    cv2.destroyAllWindows()
print("TFrecords文件生成:样本总数为 %d" % picture_cnt)
def read_and_decode(filename):
    """Build the graph nodes that read one (image, label) pair from a TFRecords file.

    Args:
        filename: path of the TFRecords file to read.

    Returns:
        A tuple (img, label): img is a uint8 tensor reshaped to
        [image_size, image_size, 1] and label is the stored class index
        cast to int32. Each evaluation in a session yields the next record.
    """
    queue = tf.train.string_input_producer([filename])
    record_reader = tf.TFRecordReader()
    _, record = record_reader.read(queue)

    # Must mirror the feature names and types used when the file was written.
    schema = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(record, features=schema)

    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    image = tf.reshape(pixels, [image_size, image_size, 1])
    return image, tf.cast(parsed['label'], tf.int32)
# Read the freshly written TFRecords file back as a sanity check: decode
# every record, save it as a JPEG and count the samples per class.
image2, labels2 = read_and_decode(tf_save_path)

# cv2.imwrite does not create missing directories and only signals failure
# through its return value, so make sure the output directory exists.
if not os.path.isdir(tf_out_path):
    os.makedirs(tf_out_path)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    # The coordinator lets the queue-runner threads be stopped and joined.
    coor = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coor)
    Classes_cnt = np.zeros([class_id_cnt], np.int32)
    try:
        for i in range(picture_cnt):
            example, class_num = sess.run([image2, labels2])
            # Show a preview of every 30th decoded image.
            if i % 30 == 0:
                cv2.namedWindow("image_out", 0)
                cv2.imshow("image_out", example)
                cv2.waitKey(1)
            # NOTE(review): the original listing lost its indentation, so it
            # is unclear whether the save was meant to run for every image or
            # only for the previewed ones; every image is saved here - confirm.
            out_file = tf_out_path + str(i) + '_Label_' + str(class_num) + '.jpg'
            cv2.imwrite(out_file, example)
            Classes_cnt[class_num] += 1
    finally:
        # Stop and join the reader threads even if the loop above raised.
        coor.request_stop()
        coor.join(threads)
    for i in range(class_id_cnt):
        # %-formatting keeps the text readable on Python 2, where
        # print(a, b, ...) would print the repr of a tuple instead.
        print("分类号 %d  =  %d  个样本" % (i, Classes_cnt[i]))
    cv2.destroyAllWindows()  # destroy the OpenCV preview windows
    # No explicit sess.close(): the with-statement closes the session.
print("\n3)TfRecords测试 转换成功 ! ")
print("well done!")
下面对代码详解:
代码开始处将各个存放路径都以字符串的形式写出,以便后面引用。
f = open(labels_path)
读取文件,将文件中的分类都读取到代码中,取出类放到classes_read中。
writer = tf.python_io.TFRecordWriter(tf_save_path) 用于生成tfrecords文件,其中利用cv2进行图像的显示,利用PIL(Image)进行图像的读取和缩放。
img_raw = img.tobytes() example = tf.train.Example(features=tf.train.Features(feature={ "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) 这段代码将每张图片的标签和像素数据都写到tfrecords中。
下面就是读取tfrecords文件中的内容。
def read_and_decode(filename): filename_queue = tf.train.string_input_producer([filename]) reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'label':tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string), } ) img = tf.decode_raw(features['img_raw'], tf.uint8) img = tf.reshape(img, [image_size, image_size, 1]) label = tf.cast(features['label'], tf.int32) return img, label image2, labels2 = read_and_decode(tf_save_path) 这是读取的代码:该函数只调用一次,用来构建读取数据的计算图节点;之后在会话中启动队列线程,每次运行这些节点就取出一条记录。
coor = tf.train.Coordinator() #create a thread threads = tf.train.start_queue_runners(coord=coor)