Converting an image dataset to TFRecords format with TensorFlow (code)

Just a few snippets of code for a quick implementation. The underlying principles are not explained in detail.

Using tensorflow-gpu 2.x or later:

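BATCH_SIZE = 2  # not used by the conversion itself
train_dir = "C:\\Users\\Desktop\\泸州老窖精品头曲组合装\\"
train_tfrecord = "C:\\Users\\Desktop\\train.tfrecords"
# run this after the imports and helper functions below have been defined
dataset_to_tfrecord(dataset_dir=train_dir, tfrecord_name=train_tfrecord)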

The custom functions are as follows:

import pathlib
import random

import tensorflow as tf


def dataset_to_tfrecord(dataset_dir, tfrecord_name):
    image_paths, image_labels = get_images_and_labels(dataset_dir)
    # map each image path to its label (paths are unique, so no pair is lost)
    image_paths_and_labels_dict = dict(zip(image_paths, image_labels))
    # shuffle the data
    image_paths_and_labels_dict = shuffle_dict(image_paths_and_labels_dict)
    with tf.io.TFRecordWriter(path=tfrecord_name) as writer:
        for image_path, label in image_paths_and_labels_dict.items():
            print("Writing to tfrecord: {}".format(image_path))
            # read the encoded image file as raw bytes; the with-block closes it
            with open(image_path, 'rb') as f:
                image_string = f.read()
            tf_example = image_example(image_string, label)
            writer.write(tf_example.SerializeToString())
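Each call to writer.write() above appends one record, and the resulting TFRecord file is simply those records in write order. A minimal sketch of that round trip, assuming a throwaway temp file and three hypothetical byte strings standing in for serialized tf.train.Example messages:

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical payloads; dataset_to_tfrecord writes serialized Examples instead.
records = [b"first", b"second", b"third"]
path = os.path.join(tempfile.mkdtemp(), "demo.tfrecords")

with tf.io.TFRecordWriter(path) as writer:
    for rec in records:
        writer.write(rec)  # each write() appends one length-prefixed record

# TFRecordDataset yields each record as a scalar tf.string tensor, in write order.
read_back = [r.numpy() for r in tf.data.TFRecordDataset(path)]
print(read_back)  # [b'first', b'second', b'third']
```

Iterating tf.data.TFRecordDataset over train.tfrecords in the same way is a quick check that the number of records matches the number of images.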

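Read the data from the folder. The folder is organized like this: the train folder contains many subfolders, and each subfolder represents one class. image_paths is the list of paths to all images under train, and image_labels is the matching list of integer class labels (values 0, 1, 2, ...), one per image.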

[Figure 1: example layout of the train folder, with one subfolder per class]

def get_images_and_labels(data_root_dir):
    data_root = pathlib.Path(data_root_dir)
    # get all images' paths (format: string); skip anything that is not a file
    all_image_path = [str(path) for path in data_root.glob('*/*') if path.is_file()]
    # get labels' names: one per class subfolder, sorted so indices are stable
    # (iterdir + is_dir instead of glob('*/'), which also matches files before Python 3.11)
    label_names = sorted(item.name for item in data_root.iterdir() if item.is_dir())
    # dict: {label : index}
    label_to_index = {label: index for index, label in enumerate(label_names)}
    # get all images' labels from the name of each image's parent folder
    all_image_label = [label_to_index[pathlib.Path(single_image_path).parent.name]
                       for single_image_path in all_image_path]
    return all_image_path, all_image_label
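To see what this labeling rule produces for the folder layout described earlier, here is a self-contained sketch that builds a temporary train-style folder with two hypothetical classes (cat and dog) and applies the same rule: class indices come from the sorted subfolder names, and each image gets the index of its parent folder. The helper name labels_from_layout is hypothetical, and unlike get_images_and_labels it also sorts the image paths, only so that the output is deterministic.

```python
import pathlib
import tempfile


def labels_from_layout(root_dir):
    """Hypothetical helper: label each image under root_dir/<class_name>/."""
    root = pathlib.Path(root_dir)
    # class index = position of the subfolder name in sorted order
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    label_to_index = {name: i for i, name in enumerate(class_names)}
    # sorted only to make the output order deterministic
    paths = sorted(str(p) for p in root.glob("*/*") if p.is_file())
    labels = [label_to_index[pathlib.Path(p).parent.name] for p in paths]
    return class_names, paths, labels


# Build a throwaway train/ folder with two hypothetical classes.
with tempfile.TemporaryDirectory() as tmp:
    for cls, files in {"cat": ["a.jpg", "b.jpg"], "dog": ["c.jpg"]}.items():
        (pathlib.Path(tmp) / cls).mkdir()
        for name in files:
            (pathlib.Path(tmp) / cls / name).write_bytes(b"fake image bytes")
    class_names, paths, labels = labels_from_layout(tmp)

print(class_names)  # ['cat', 'dog']
print(labels)       # [0, 0, 1]
```

Because the indices follow sorted folder names, adding or renaming a class folder can shift the label of every class after it.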

def shuffle_dict(original_dict):
    # shuffle the keys, then rebuild the dict in that order
    # (dicts preserve insertion order in Python 3.7+)
    keys = list(original_dict.keys())
    random.shuffle(keys)
    return {key: original_dict[key] for key in keys}
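Shuffling through a dict works here only because every image path is a unique key; duplicate keys would collapse into one entry. A sketch of an alternative that shuffles the (path, label) pairs directly as a list, with an optional seed so the order can be reproduced. The function shuffle_pairs and the sample data are hypothetical:

```python
import random


def shuffle_pairs(paths, labels, seed=None):
    """Shuffle paths and labels together so each path keeps its label."""
    pairs = list(zip(paths, labels))
    random.Random(seed).shuffle(pairs)  # private RNG; seed=None means unseeded
    return [p for p, _ in pairs], [lbl for _, lbl in pairs]


paths = ["cat/a.jpg", "cat/b.jpg", "dog/c.jpg"]
labels = [0, 0, 1]
shuffled_paths, shuffled_labels = shuffle_pairs(paths, labels, seed=42)

# Every (path, label) pair survives the shuffle unchanged.
assert sorted(zip(shuffled_paths, shuffled_labels)) == sorted(zip(paths, labels))
```

The original shuffle_dict uses the global random state, so its order changes on every run unless random.seed() is called first.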

Convert the data into tf.train.Example format:

def _int64_feature(value):
    # Returns an int64_list from a bool / enum / int / uint.
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    # Returns a bytes_list from a string / byte.
    if isinstance(value, type(tf.constant(0.))):
        value = value.numpy()   # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def image_example(image_string, label):
    feature = {
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image_string)
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
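To check what image_example produces, the sketch below serializes one Example with a fake image payload and the label 3 (both hypothetical), then parses it back with tf.io.parse_single_example. The keys and dtypes in feature_description are an assumption that must match what was written: 'label' as int64 and 'image_raw' as a string.

```python
import tensorflow as tf


def image_example(image_string, label):
    # Compact copy of image_example above: an int64 label plus raw image bytes.
    feature = {
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_string])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


serialized = image_example(b"fake jpeg bytes", 3).SerializeToString()

# Describe the features to parse; keys and dtypes must match what was written.
feature_description = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "image_raw": tf.io.FixedLenFeature([], tf.string),
}
parsed = tf.io.parse_single_example(serialized, feature_description)
label = int(parsed["label"].numpy())
image_raw = parsed["image_raw"].numpy()
print(label, image_raw)  # 3 b'fake jpeg bytes'
```

image_raw comes back as the original encoded file bytes, so a training pipeline would still need to decode it, for example with tf.io.decode_image.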
