将图片数据转换为 TFRecord 格式

假设 emo 文件夹下有 1、2、3、4 等子文件夹,每个子文件夹代表一个类别,其中存放该类别的 .jpg 图片;类别标签是子文件夹在遍历顺序中的序号(从 0 开始),而不是文件夹名本身。

1 import tensorflow as tf
  2 from PIL import Image
  3 from glob import glob
  4 import os
  5 import progressbar
  6 import time
  7 
  8 
  9 class TFRecord():
 10     def __init__(self, path=None, tfrecord_file=None):
 11         self.path = path
 12         self.tfrecord_file = tfrecord_file
 13 
 14     def _convert_image(self, idx, img_path, is_training=True):
 15         label = idx
 16 
 17         with tf.gfile.FastGFile(img_path, 'rb') as fid:
 18             img_str = fid.read()
 19 
 20         # img_data = Image.open(img_path)
 21         # img_data = img_data.resize((224, 224))
 22         # img_str = img_data.tobytes()
 23 
 24         file_name = img_path
 25 
 26         if is_training:
 27             feature_key_value_pair = {
 28                 'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(
 29                     value=[file_name.encode()])),
 30                 'img': tf.train.Feature(bytes_list=tf.train.BytesList(
 31                     value=[img_str])),
 32                 'label': tf.train.Feature(int64_list=tf.train.Int64List(
 33                     value=[label]))
 34             }
 35         else:
 36             feature_key_value_pair = {
 37                 'file_name': tf.train.Feature(bytes_list=tf.train.BytesList(
 38                     value=[file_name.encode()])),
 39                 'img': tf.train.Feature(bytes_list=tf.train.BytesList(
 40                     value=[img_str])),
 41                 'label': tf.train.Feature(int64_list=tf.train.Int64List(
 42                     value=[-1]))
 43             }
 44 
 45         feature = tf.train.Features(feature=feature_key_value_pair)
 46         example = tf.train.Example(features=feature)
 47         return example
 48 
 49     def convert_img_folder(self):
 50 
 51         folder_path = self.path
 52         tfrecord_path = self.tfrecord_file
 53         img_paths = []
 54         for file in os.listdir(folder_path):
 55             for img_path in os.listdir(os.path.join(folder_path, file)):
 56                 img_paths.append(os.path.join(folder_path, file, img_path))
 57 
 58 
 59         with tf.python_io.TFRecordWriter(tfrecord_path) as tfwrite:
 60             widgets = ['[INFO] write image to tfrecord: ', progressbar.Percentage(), " ",
 61                        progressbar.Bar(), " ", progressbar.ETA()]
 62             pbar = progressbar.ProgressBar(maxval=len(img_paths), widgets=widgets).start()
 63 
 64             cate = [folder_path + '/' + x for x in os.listdir(folder_path) if
 65                     os.path.isdir(folder_path + '/' + x)]
 66 
 67             i = 0
 68             for idx, folder in enumerate(cate):
 69                 for img_path in glob(folder + '/*.jpg'):
 70                     example = self._convert_image(idx, img_path)
 71                     tfwrite.write(example.SerializeToString())
 72                     pbar.update(i)
 73                     i += 1
 74 
 75             pbar.finish()
 76 
 77     def _extract_fn(self, tfrecord):
 78         feautres = {
 79             'file_name': tf.FixedLenFeature([], tf.string),
 80             'img': tf.FixedLenFeature([], tf.string),
 81             'label': tf.FixedLenFeature([], tf.int64)
 82         }
 83         sample = tf.parse_single_example(tfrecord, feautres)
 84         img = tf.image.decode_jpeg(sample['img'])
 85         img = tf.image.resize_images(img, (224, 224), method=1)
 86         label = sample['label']
 87         file_name = sample['file_name']
 88         return [img, label, file_name]
 89 
 90     def extract_image(self, shuffle_size, batch_size):
 91         dataset = tf.data.TFRecordDataset([self.tfrecord_file])
 92         dataset = dataset.map(self._extract_fn)
 93         dataset = dataset.shuffle(shuffle_size).batch(batch_size)
 94         print("---------", type(dataset))
 95         return dataset
 96 
 97 
 98 
 99 
if __name__ == '__main__':
    # Demo: write the emo dataset to a TFRecord file, then read it back in
    # batches and print each batch's labels and image tensor shape.

    # Looping over a tf.data dataset with a plain `for` requires eager
    # execution, which is off by default in TensorFlow 1.x. It must be
    # switched on before any TensorFlow op is created.
    if not tf.executing_eagerly():
        tf.enable_eager_execution()

    t = TFRecord('/media/xia/Data/emo', '/media/xia/Data/emo.tfrecord')
    t.convert_img_folder()
    dataset = t.extract_image(100, 64)
    for data, label, _ in dataset:
        print(label)
        print(data.shape)

你可能感兴趣的:(androidflutter)