本文以flower分类数据集为例,数据集存储格式为:以类别名命名文件夹,每个文件夹中存放对应类别的图像
1.根据下载好的分类数据集进行随机分割,读取,保存
linux系统下输入以下命令:
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar xzf flower_photos.tgz
windows系统下直接在浏览器中打开下面的链接下载:
http://download.tensorflow.org/example_images/flower_photos.tgz
def get_dataset_dict(imagedir, train_percentage=8):
    """Randomly split a folder-per-class image dataset into train/test lists.

    Every immediate sub-directory of ``imagedir`` is treated as one class and
    every entry inside it as one image of that class.

    Args:
        imagedir: Root directory containing one sub-directory per class.
        train_percentage: Share of images that should go to the training
            split, expressed in tenths (``8`` means roughly 80% train).

    Returns:
        A ``(dataset, label)`` tuple. ``dataset`` maps each class name to a
        dict with the keys ``'dir'`` (class directory), ``'train'`` and
        ``'test'`` (lists of file names). ``label`` maps each class name to
        an integer id.

    Raises:
        IndexError: If ``imagedir`` cannot be walked (e.g. it does not exist).
    """
    # Only the first os.walk entry (the root itself) is needed, so do not
    # traverse the whole tree.
    top_level = next(os.walk(imagedir), None)
    if top_level is None:
        # IndexError is what the previous implementation raised in this case.
        raise IndexError('cannot list image directory: {}'.format(imagedir))
    # Sorted so that the class -> id mapping is the same on every run.
    class_names = sorted(top_level[1])

    train_fraction = train_percentage / 10.0
    dataset = {}
    label = {}
    for class_id, class_name in enumerate(class_names):
        subdir = os.path.join(imagedir, class_name)
        label[class_name] = class_id
        train_dataset = []
        test_dataset = []
        for image in os.listdir(subdir):
            # random.random() is uniform on [0, 1), so each image lands in
            # the train split with probability exactly ``train_fraction``
            # (always for 10, never for 0).
            if random.random() < train_fraction:
                train_dataset.append(image)
            else:
                test_dataset.append(image)
        dataset[class_name] = {
            'dir': subdir,
            'train': train_dataset,
            'test': test_dataset,
        }
    return dataset, label
2.以tfrecords格式保存数据
def create_tfrecord(dataset, label, tfrecord_dir, dataset_type='train', resize=None):
    """Write one split of the dataset to ``<tfrecord_dir>/<dataset_type>.tfrecords``.

    Args:
        dataset: Mapping of class name to a dict with a ``'dir'`` entry and
            one list of file names per split (indexed by ``dataset_type``).
        label: Mapping of class name to integer class id.
        tfrecord_dir: Directory in which the ``.tfrecords`` file is created.
        dataset_type: Key of the split to export, also used as the file stem.
        resize: Forwarded unchanged to ``create_tfrecord_example``.
    """
    tfrecord_path = os.path.join(tfrecord_dir, dataset_type + '.tfrecords')
    # The context manager closes the writer even when an image fails to load
    # or serialise part-way through the export.
    with tf.python_io.TFRecordWriter(tfrecord_path) as writer:
        for classname, info in dataset.items():
            for imagefile in info[dataset_type]:
                example = create_tfrecord_example(
                    classname,
                    os.path.join(info['dir'], imagefile),
                    label,
                    resize)
                writer.write(example.SerializeToString())
def create_tfrecord_example(classname, imagefile, label, resize=None):
    """Build a ``tf.train.Example`` holding one image and its class id.

    The image is stored as the raw bytes returned by ``Image.tobytes()``
    (not as an encoded file), so a reader has to know the image shape to
    rebuild it.

    Args:
        classname: Class name used to look up the id in ``label``.
        imagefile: Path of the image file to load.
        label: Mapping of class name to integer class id.
        resize: Optional size passed to ``Image.resize``; ``None`` keeps the
            original size.

    Returns:
        A ``tf.train.Example`` with an int64 ``'label'`` feature and a bytes
        ``'image'`` feature.
    """
    # Open via a context manager so the underlying file handle is released
    # as soon as the pixel data has been extracted.
    with Image.open(imagefile) as pil_image:
        if resize is not None:
            pil_image = pil_image.resize(resize)
        bytes_image = pil_image.tobytes()
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': int64_feature(label[classname]),
        'image': bytes_feature(bytes_image),
    }))
    return example
def int64_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def bytes_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by a BytesList."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
3.读取tfrecords格式数据
def read_tfrecord(tfrecord_path, resize, batch_size=1):
    """Build queue-based ops that yield shuffled (image, label) batches.

    Args:
        tfrecord_path: Path of the ``.tfrecords`` file to read.
        resize: Shape handed to ``tf.reshape`` for the decoded uint8 image
            data.
        batch_size: Number of examples per batch.

    Returns:
        A ``(img_batch, label_batch)`` tuple of batched tensors produced by
        ``tf.train.shuffle_batch``.
    """
    print('tfrecord:{}'.format(tfrecord_path))

    # Each record carries a scalar int64 label and the image as one string.
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'image': tf.FixedLenFeature([], tf.string),
    }

    path_queue = tf.train.string_input_producer([tfrecord_path])
    record_reader = tf.TFRecordReader()
    _, serialized_example = record_reader.read(path_queue)
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The image was stored as raw bytes: decode to uint8, then apply the
    # caller-supplied shape.
    pixels = tf.decode_raw(parsed['image'], tf.uint8)
    pixels = tf.reshape(pixels, resize)
    class_id = tf.cast(parsed['label'], tf.int64)

    img_batch, label_batch = tf.train.shuffle_batch(
        [pixels, class_id],
        batch_size=batch_size,
        capacity=392,
        min_after_dequeue=200)
    return img_batch, label_batch
4.示例
if __name__ == "__main__":
    # Split the images under IMAGEDIR (roughly 80% train / 20% test).
    dataset, label = get_dataset_dict(IMAGEDIR, 8)

    # Persist the split and the class -> id mapping as JSON. The ``with``
    # blocks close the files, so no explicit close() is needed.
    with open('j_d.json', 'w', encoding='utf-8') as f:
        json.dump(dataset, f)
    with open('j_l.json', 'w', encoding='utf-8') as f:
        json.dump(label, f)

    # Export both splits as TFRecords, resizing every image to 256x256.
    for split in ('train', 'test'):
        create_tfrecord(dataset, label, TFRECORDDIR, split, (256, 256))