假设 emo 文件夹下有 1、2、3、4 等子文件夹,每个子文件夹代表一个类别,其中存放该类别的 .jpg 图片。
import tensorflow as tf
from PIL import Image
from glob import glob
import os
import progressbar
import time


class TFRecord():
    """Write a folder-per-class JPEG dataset to a TFRecord file and read it back.

    The dataset root (``path``) is expected to contain one sub-directory per
    class; every ``*.jpg`` file directly inside a sub-directory is one sample
    of that class. Written against the TensorFlow 1.x API (``tf.python_io``,
    ``tf.gfile``, ``tf.parse_single_example``).

    Args:
        path: Dataset root directory.
        tfrecord_file: Path of the TFRecord file to write and/or read.
    """

    def __init__(self, path=None, tfrecord_file=None):
        self.path = path
        self.tfrecord_file = tfrecord_file

    def _convert_image(self, idx, img_path, is_training=True):
        """Build a ``tf.train.Example`` for one image file.

        The file is stored as its raw (still encoded) bytes, so it is decoded
        only when the record is read back.

        Args:
            idx: Integer class label of the image.
            img_path: Path of the image file.
            is_training: When False the label is stored as -1 instead of
                ``idx``.

        Returns:
            A ``tf.train.Example`` with the features ``file_name`` (bytes),
            ``img`` (bytes) and ``label`` (int64).
        """
        with tf.gfile.FastGFile(img_path, 'rb') as fid:
            img_str = fid.read()

        label = idx if is_training else -1
        feature_key_value_pair = {
            'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[img_path.encode()])),
            'img': tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[img_str])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(
                value=[label])),
        }
        features = tf.train.Features(feature=feature_key_value_pair)
        return tf.train.Example(features=features)

    def _list_samples(self):
        """Return ``(label, image_path)`` pairs for every ``*.jpg`` in the dataset.

        Class directories are sorted by name before labels are assigned, so
        the class-to-label mapping does not depend on the order in which the
        file system happens to list them. The sort is lexicographic, so a
        directory named ``10`` comes before one named ``2``. Entries of the
        dataset root that are not directories are ignored.
        """
        class_dirs = sorted(
            os.path.join(self.path, name)
            for name in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, name))
        )
        samples = []
        for label, class_dir in enumerate(class_dirs):
            for img_path in sorted(glob(os.path.join(class_dir, '*.jpg'))):
                samples.append((label, img_path))
        return samples

    def convert_img_folder(self):
        """Write every ``*.jpg`` under ``self.path`` to ``self.tfrecord_file``.

        The label of an image is the index of its class directory in the
        sorted list of class directories.

        Raises:
            ValueError: If ``path`` or ``tfrecord_file`` was not set, or if no
                ``*.jpg`` image is found.
        """
        if self.path is None or self.tfrecord_file is None:
            # os.listdir(None) would silently scan the current directory.
            raise ValueError('both path and tfrecord_file must be set')

        samples = self._list_samples()
        if not samples:
            raise ValueError('no *.jpg images found under %r' % (self.path,))

        widgets = ['[INFO] write image to tfrecord: ', progressbar.Percentage(), " ",
                   progressbar.Bar(), " ", progressbar.ETA()]

        with tf.python_io.TFRecordWriter(self.tfrecord_file) as tfwrite:
            # maxval is the number of records actually written, so the bar
            # ends at exactly 100%.
            pbar = progressbar.ProgressBar(maxval=len(samples), widgets=widgets).start()
            for written, (label, img_path) in enumerate(samples, 1):
                example = self._convert_image(label, img_path)
                tfwrite.write(example.SerializeToString())
                pbar.update(written)
            pbar.finish()

    def _extract_fn(self, tfrecord, channels=0):
        """Parse one serialized example into ``(img, label, file_name)``.

        Args:
            tfrecord: Scalar string tensor holding a serialized
                ``tf.train.Example`` written by ``_convert_image``.
            channels: Number of colour channels to decode. 0 (the default)
                keeps the channel count stored in the JPEG file; pass 3 to
                force RGB or 1 to force grayscale.

        Returns:
            A tuple ``(img, label, file_name)``: the decoded image resized to
            224x224 with nearest-neighbour interpolation, the int64 label and
            the file name as a string tensor.
        """
        features = {
            'file_name': tf.FixedLenFeature([], tf.string),
            'img': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        }
        sample = tf.parse_single_example(tfrecord, features)
        img = tf.image.decode_jpeg(sample['img'], channels=channels)
        img = tf.image.resize_images(
            img, (224, 224), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return img, sample['label'], sample['file_name']

    def extract_image(self, shuffle_size, batch_size, channels=0):
        """Build a shuffled, batched dataset from ``self.tfrecord_file``.

        Args:
            shuffle_size: Size of the shuffle buffer.
            batch_size: Number of samples per batch.
            channels: Forwarded to ``_extract_fn``. Leave at 0 to keep each
                file's own channel count; set it to 1 or 3 when the dataset
                mixes grayscale and colour images, because images with
                different channel counts cannot be stacked into one batch.

        Returns:
            A ``tf.data.Dataset`` yielding ``(images, labels, file_names)``
            batches.
        """
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        dataset = dataset.map(
            lambda record: self._extract_fn(record, channels=channels))
        return dataset.shuffle(shuffle_size).batch(batch_size)


if __name__ == '__main__':
    # In TF 1.x a Dataset can only be iterated with a plain ``for`` loop when
    # eager execution is on, and it has to be enabled before any other
    # TensorFlow call.
    tf.enable_eager_execution()

    t = TFRecord('/media/xia/Data/emo', '/media/xia/Data/emo.tfrecord')
    t.convert_img_folder()
    dataset = t.extract_image(100, 64)
    for data, label, _ in dataset:
        print(label)
        print(data.shape)
ps: 上面用 for 循环直接遍历 dataset 的写法依赖 eager 模式,需要在程序开头(import tensorflow 之后、调用其他 TF 接口之前)先执行 tf.enable_eager_execution()。
测试环境:tf.__version__ == '1.8.0'
参考:https://zhuanlan.zhihu.com/p/30751039
https://lonepatient.top/2018/06/01/tensorflow_tfrecord.html
https://zhuanlan.zhihu.com/p/51186668