TensorFlow: reading a classification dataset, randomly splitting it into a training set and a test set, and saving both as TFRecords

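This post uses the flower classification dataset as an example. The dataset has one folder per class, named after that class, and each folder holds the images belonging to it.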

1. Randomly split, read, and save the downloaded classification dataset

On Linux, run the following commands:

wget http://download.tensorflow.org/example_images/flower_photos.tgz

tar xzf flower_photos.tgz

On Windows, download the archive directly from this link:

http://download.tensorflow.org/example_images/flower_photos.tgz
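On either platform, the download and extraction can also be done from Python with only the standard library. This is a convenience sketch, not part of the original workflow; `download_and_extract` is a hypothetical helper name, and the destination defaults to the current directory (after extraction the images sit under `flower_photos/`).

```python
import os
import tarfile
import urllib.request

FLOWER_URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"


def download_and_extract(url=FLOWER_URL, dest_dir="."):
    """Download a .tgz archive (unless already present) and extract it into dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    # "r:gz" opens a gzip-compressed tar archive, like `tar xzf`
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # The 'data' filter (Python 3.12+ and some backported patch releases)
            # rejects unsafe members such as absolute paths or '..' components.
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    return archive_path
```

Calling `download_and_extract()` with no arguments fetches the flower archive once and skips the download on later runs.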

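import os
import random


def get_dataset_dict(imagedir, train_percentage=8):
    """Randomly split each class folder into train/test file lists.

    train_percentage is in tenths: 8 means roughly 80% train and 20% test.
    """
    # Only the top level of os.walk is needed: its subdirectories are the classes.
    # Sorting makes the class -> label mapping reproducible across runs.
    category = sorted(next(os.walk(imagedir))[1])
    dataset = {}
    label = {}
    for j, class_name in enumerate(category):
        subdir = os.path.join(imagedir, class_name)
        imagelist = os.listdir(subdir)
        label[class_name] = j
        train_dataset = []
        test_dataset = []
        for image in imagelist:
            # random.random() is uniform in [0, 1), so each image goes to the
            # training set with probability train_percentage / 10.
            if random.random() < train_percentage / 10.0:
                train_dataset.append(image)
            else:
                test_dataset.append(image)
        dataset[class_name] = {
            'dir': subdir,
            'train': train_dataset,
            'test': test_dataset
        }
    return dataset, label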
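Because each image is assigned independently, the 80/20 proportion only holds on average, and a small class can come out noticeably off. If an exact per-class proportion is wanted, one alternative (a different technique from the coin-flip approach above) is to shuffle each file list with a fixed seed and slice it. `split_file_list` and its `seed` parameter are hypothetical names for illustration.

```python
import random


def split_file_list(filenames, train_percentage=8, seed=0):
    """Shuffle a list of file names and cut it at train_percentage / 10.

    Returns (train, test). The same input set and seed always give the same split.
    """
    # Sort first so the result does not depend on os.listdir() ordering
    shuffled = sorted(filenames)
    # A private generator leaves the global random state untouched
    rng = random.Random(seed)
    rng.shuffle(shuffled)
    n_train = int(round(len(shuffled) * train_percentage / 10.0))
    return shuffled[:n_train], shuffled[n_train:]
```

Inside the per-class loop this would replace the inner `for` loop with `train_dataset, test_dataset = split_file_list(imagelist, train_percentage)`.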

2. Save the data in TFRecords format

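import os

import tensorflow as tf  # TF 1.x API; in TF 2.x the writer is tf.io.TFRecordWriter
from PIL import Image


def create_tfrecord(dataset, label, tfrecord_dir, dataset_type='train', resize=None):
    # Write every image of one split ('train' or 'test') into a single file,
    # e.g. <tfrecord_dir>/train.tfrecords
    writer = tf.python_io.TFRecordWriter(
        os.path.join(tfrecord_dir, dataset_type + '.tfrecords'))
    for classname, info in dataset.items():
        for imagefile in info[dataset_type]:
            example = create_tfrecord_example(
                classname, os.path.join(info['dir'], imagefile), label, resize)
            writer.write(example.SerializeToString())
    writer.close()


def create_tfrecord_example(classname, imagefile, label, resize=None):
    # convert('RGB') guarantees 3 channels, so every record holds
    # height * width * 3 bytes and can be reshaped when reading
    pil_image = Image.open(imagefile).convert('RGB')
    if resize is not None:
        # PIL's resize takes (width, height)
        pil_image = pil_image.resize(resize)
    bytes_image = pil_image.tobytes()  # raw uint8 pixels, not JPEG-encoded
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': int64_feature(label[classname]),
        'image': bytes_feature(bytes_image)
        # 'format': bytes_feature(b'jpg')
    }))
    return example


def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def bytes_feature(value):
    # value must be bytes; in Python 3 a str has to be encoded first
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))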
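`TFRecordWriter` stores each serialized `tf.train.Example` as a length-prefixed record. Following the record layout described in TensorFlow's TFRecord guide (a little-endian uint64 length, a masked CRC-32C of the length, the payload, then a masked CRC-32C of the payload), the sketch below writes and reads that framing in pure Python, which can be handy for counting records in `train.tfrecords` without TensorFlow. It only handles uncompressed files; `crc32c`, `masked_crc`, `write_record`, and `read_records` are hypothetical helpers, and the bitwise CRC-32C is slow but simple.

```python
import struct


def crc32c(data):
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """TFRecord masking: rotate right by 15 bits, then add a constant (mod 2**32)."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, payload):
    header = struct.pack("<Q", len(payload))  # uint64 length, little-endian
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))
    f.write(payload)
    f.write(struct.pack("<I", masked_crc(payload)))


def read_records(f):
    """Yield payloads from an uncompressed TFRecord stream, verifying both CRCs."""
    while True:
        header = f.read(8)
        if not header:
            return
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", f.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        payload = f.read(length)
        (payload_crc,) = struct.unpack("<I", f.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload
```

For example, `sum(1 for _ in read_records(open('train.tfrecords', 'rb')))` would count the training examples.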

3. Read the data in TFRecords format

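def read_tfrecord(tfrecord_path, resize, batch_size=1):
    # TF 1.x queue-based input pipeline; in TF 2.x these calls live under
    # tf.compat.v1 and require graph mode
    print('tfrecord:{}'.format(tfrecord_path))
    filename_queue = tf.train.string_input_producer([tfrecord_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    # resize is the same (width, height) tuple passed to create_tfrecord;
    # the raw pixel buffer is laid out as (height, width, 3)
    image = tf.reshape(image, [resize[1], resize[0], 3])
    label = tf.cast(features['label'], tf.int64)
    # The batches only produce data after tf.train.start_queue_runners()
    # has been called inside a session
    img_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                    batch_size=batch_size,
                                                    capacity=392,
                                                    min_after_dequeue=200)
    return img_batch, label_batch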
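The reshape only succeeds if the stored byte count matches the target shape. `PIL.Image.resize` takes `(width, height)`, whereas the decoded buffer is laid out as `(height, width, channels)`, and `tobytes()` of an RGB image holds `height * width * 3` bytes. The small stdlib-only helpers below (hypothetical names) make that bookkeeping explicit, which matters once the images are not square.

```python
def raw_image_shape(resize, channels=3):
    """Map a PIL (width, height) size to the (height, width, channels) layout of tobytes()."""
    width, height = resize
    return (height, width, channels)


def raw_byte_count(resize, channels=3):
    """Number of uint8 values that one record's 'image' feature must contain."""
    height, width, channels = raw_image_shape(resize, channels)
    return height * width * channels
```

With the `(256, 256)` size used in the example below, each record's image feature holds `raw_byte_count((256, 256))` = 196608 bytes.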

4. Example

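import json
import os

# Placeholder paths (assumptions): point these at your own directories
IMAGEDIR = 'flower_photos'
TFRECORDDIR = 'tfrecords'

if __name__ == "__main__":
    dataset, label = get_dataset_dict(IMAGEDIR, 8)
    # Save the split and the class -> label mapping so they can be reloaded later
    with open('j_d.json', 'w', encoding='utf-8') as f:
        json.dump(dataset, f)
    with open('j_l.json', 'w', encoding='utf-8') as f:
        json.dump(label, f)
    os.makedirs(TFRECORDDIR, exist_ok=True)  # the output directory must exist
    create_tfrecord(dataset, label, TFRECORDDIR, 'train', (256, 256))
    create_tfrecord(dataset, label, TFRECORDDIR, 'test', (256, 256))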
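The split saved to JSON can be reloaded later, for example to reuse the same train/test lists or to map predicted label ids back to class names. A minimal sketch, assuming the `j_d.json` / `j_l.json` files written above; `load_split` and `count_images` are hypothetical helpers. Note that JSON object keys are always strings, so the inverse map is built in Python after loading.

```python
import json


def load_split(dataset_path='j_d.json', label_path='j_l.json'):
    """Reload the saved split and label map, and build the inverse id -> class map."""
    with open(dataset_path, encoding='utf-8') as f:
        dataset = json.load(f)
    with open(label_path, encoding='utf-8') as f:
        label = json.load(f)
    id_to_class = {v: k for k, v in label.items()}
    return dataset, label, id_to_class


def count_images(dataset):
    """Return {class_name: (n_train, n_test)} for a quick sanity check of the split."""
    return {name: (len(info['train']), len(info['test']))
            for name, info in dataset.items()}
```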

 
