本文只给出几段代码,便于快速实现;原理部分不作详细说明。
需要使用 tensorflow-gpu 2.0 及以上版本,调用示例如下:
# Example invocation: convert one image folder into a single TFRecord file.
BATCH_SIZE = 2  # not passed to the conversion call below
# Every backslash is doubled so the literals contain no invalid escape
# sequences (the previous "\D" and "\泸" trigger a warning on recent Python
# versions); the resulting path strings are unchanged.
train_dir = "C:\\Users\\Desktop\\泸州老窖精品头曲组合装\\"
train_tfrecord = "C:\\Users\\Desktop\\train.tfrecords"
dataset_to_tfrecord(dataset_dir=train_dir, tfrecord_name=train_tfrecord)
自定义函数如下:
def dataset_to_tfrecord(dataset_dir, tfrecord_name):
    """Serialize every labelled image found under ``dataset_dir`` into one TFRecord file.

    Args:
        dataset_dir: Root directory handed to ``get_images_and_labels``.
        tfrecord_name: Path of the TFRecord file to write.
    """
    image_paths, image_labels = get_images_and_labels(dataset_dir)
    # Map each image path to its label (one entry per distinct path).
    image_paths_and_labels_dict = dict(zip(image_paths, image_labels))
    # Shuffle the (path, label) pairs before writing them out.
    image_paths_and_labels_dict = shuffle_dict(image_paths_and_labels_dict)
    with tf.io.TFRecordWriter(path=tfrecord_name) as writer:
        for image_path, label in image_paths_and_labels_dict.items():
            print("Writing to tfrecord: {}".format(image_path))
            # Store the file's raw bytes as-is; the context manager closes the
            # handle right away instead of leaking one descriptor per image.
            with open(image_path, 'rb') as image_file:
                image_string = image_file.read()
            tf_example = image_example(image_string, label)
            writer.write(tf_example.SerializeToString())
获取文件夹内的数据。目录结构如下:train 文件夹下有若干子文件夹,每个子文件夹代表一个类别。image_paths 返回 train 下所有图片的路径;image_labels 返回与这些路径一一对应的类别索引列表,形如 [0, 1, 2, ...]。
def get_images_and_labels(data_root_dir):
    """Collect image paths and integer class labels from a class-per-folder tree.

    Each immediate sub-directory of ``data_root_dir`` is one class. Class
    indices are assigned by sorting the sub-directory names.

    Args:
        data_root_dir: Root directory whose sub-directories hold the images.

    Returns:
        A tuple ``(all_image_path, all_image_label)``: the file paths as
        strings in sorted order, and the class index of each path at the
        same position.
    """
    data_root = pathlib.Path(data_root_dir)
    # Only directories are classes. The explicit is_dir() check keeps stray
    # files in the root from becoming labels on Python versions where a
    # trailing "/" in a glob pattern does not restrict matches to directories.
    label_names = sorted(item.name for item in data_root.glob('*') if item.is_dir())
    # dict: {label : index}
    label_to_index = {label: index for index, label in enumerate(label_names)}
    # Keep regular files only; a nested directory inside a class folder would
    # otherwise be returned as an "image" and fail when opened later.
    image_files = sorted(path for path in data_root.glob('*/*') if path.is_file())
    all_image_path = [str(path) for path in image_files]
    all_image_label = [label_to_index[path.parent.name] for path in image_files]
    return all_image_path, all_image_label
def shuffle_dict(original_dict):
    """Return a new dict with the same items as ``original_dict`` in random key order.

    The input dict is left untouched; ordering comes from ``random.shuffle``.
    """
    shuffled_keys = list(original_dict.keys())
    random.shuffle(shuffled_keys)
    return {key: original_dict[key] for key in shuffled_keys}
将数据转化为 tf.train.Example 格式:
def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in an int64_list ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap one string / byte value in a bytes_list ``tf.train.Feature``."""
    eager_tensor_type = type(tf.constant(0.))
    if isinstance(value, eager_tensor_type):
        # BytesList won't unpack a string from an EagerTensor, so pull out
        # the underlying value first.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def image_example(image_string, label):
    """Build a ``tf.train.Example`` holding a ``'label'`` and an ``'image_raw'`` feature.

    Args:
        image_string: Image content, stored through ``_bytes_feature``.
        label: Class index, stored through ``_int64_feature``.
    """
    label_feature = _int64_feature(label)
    image_feature = _bytes_feature(image_string)
    features = tf.train.Features(
        feature={'label': label_feature, 'image_raw': image_feature}
    )
    return tf.train.Example(features=features)