AI challenger 场景分类 生成tfrecord文件

与「AI Challenger 场景分类(1):生成 tfrecord 文件」不同,这里生成的是验证集的 tfrecord 文件;另外,把每张图片的 id(string 类型,即图片文件名)也一并存入 tfrecord,方便后续把每条记录与原图片对应起来。

用时:约 5 分钟
原图大小:463.9 MB
tfrecord 文件大小:8.7 GB(tfrecord 中保存的是解码后的原始像素字节,而不是 JPEG 压缩数据,因此体积远大于原图)

# -*- coding: utf-8 -*-
"""
Created on Thu Sep  7 19:25:38 2017

@author: wayne

convert the validation set to tfrecord, with image id

https://stackoverflow.com/questions/42444468/tensorflow-is-there-a-way-to-locate-the-filenames-of-images-encoded-into-tfreco

http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html

"""


import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime



# Destination folder for the generated record file.
record_PATH = '../ai_challenger_scene_validation_20170908/'
tfrecord_file = os.path.join(record_PATH, 'val.tfrecord')

# A single writer shared by every write_to_tfrecord() call; it is closed once
# all images have been written.
writer = tf.python_io.TFRecordWriter(tfrecord_file)


def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose int64 list holds exactly ``value``."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)

def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose bytes list holds exactly ``value``."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def get_image_binary(filename):
    """Decode an image file into raw RGB pixel bytes plus its shape.

    The image could be read with TensorFlow too, but that requires building a
    graph; Pillow + NumPy is simpler here.

    Args:
        filename: path to the image file.

    Returns:
        (shape, raw_bytes): ``shape`` is an int32 array ``[h, w, 3]`` and
        ``raw_bytes`` is the uint8 pixel buffer in row-major (h, w, c) order.
    """
    # Context manager closes the underlying file handle immediately instead of
    # leaving it open until garbage collection (this runs once per image).
    with Image.open(filename) as image:
        # Grayscale ('L') JPEGs decode to a 2-D array, which makes the
        # shape[2] lookup in write_to_tfrecord raise IndexError; CMYK decodes
        # to 4 channels. Normalise everything to 3-channel RGB.
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        # Materialise the pixels while the file is still open.
        pixels = np.asarray(rgb, np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()  # raw data bytes of the array

def write_to_tfrecord(label, shape, binary_image, tfrecord_file, image_id,
                      record_writer=None):
    """Serialize one image sample as a ``tf.train.Example`` and write it.

    Args:
        label: integer class id of the image.
        shape: sequence ``(h, w, c)`` describing ``binary_image``.
        binary_image: raw pixel bytes (as produced by ``get_image_binary``).
        tfrecord_file: kept only for backward compatibility; it does NOT
            select the output file -- the destination is fixed by the writer.
        image_id: image identifier as ``str`` (in this script, the image's
            file name); stored UTF-8 encoded so a record can be traced back
            to its source image.
        record_writer: optional open ``TFRecordWriter``. Defaults to the
            module-level ``writer``.
    """
    if record_writer is None:
        record_writer = writer
    feature = {
        'label': _int64_feature(int(label)),
        # shape entries may be NumPy integers (get_image_binary builds an
        # int32 array); convert explicitly to plain Python ints.
        'h': _int64_feature(int(shape[0])),
        'w': _int64_feature(int(shape[1])),
        'c': _int64_feature(int(shape[2])),
        'image': _bytes_feature(binary_image),
        'image_id': _bytes_feature(image_id.encode('utf-8')),  # str -> bytes
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    record_writer.write(example.SerializeToString())


def write_tfrecord(label, image_file, tfrecord_file, image_id):
    """Decode ``image_file`` and write it, with its label and id, as one record."""
    pixel_shape, pixel_bytes = get_image_binary(image_file)
    write_to_tfrecord(label, pixel_shape, pixel_bytes, tfrecord_file, image_id)

# Load the validation annotation file (label file). Each entry is read below
# via item['image_id'] and item['label_id']. JSON is UTF-8, so decode it
# explicitly instead of relying on the platform's locale encoding.
with open('../ai_challenger_scene_validation_20170908/scene_validation_annotations_20170908.json', 'r', encoding='utf-8') as f:
    label_raw = json.load(f)

def file_name2(file_dir, ext='.jpg'):
    """Recursively collect files with extension ``ext`` under ``file_dir``.

    Args:
        file_dir: root directory to walk (sub-directories are included).
        ext: extension to keep, including the leading dot. Matching is
            case-insensitive, so ``'.jpg'`` also picks up ``'.JPG'`` files.

    Returns:
        (paths, names): two parallel lists -- ``paths`` holds the full path of
        each match, ``names`` its bare file name. Directories and files are
        visited in sorted order so the output (and therefore the order of
        records written from it) is reproducible across runs and filesystems.
    """
    ext = ext.lower()
    paths = []
    names = []
    for root, dirs, files in os.walk(file_dir):
        # Sorting ``dirs`` in place controls the order os.walk descends into.
        dirs.sort()
        for name in sorted(files):
            if os.path.splitext(name)[1].lower() == ext:
                paths.append(os.path.join(root, name))
                names.append(name)
    return paths, names

# Collect every .jpg under the validation image directory: full paths plus
# bare file names (the file names are the keys used in the annotation file).
path, image_id = file_name2('../ai_challenger_scene_validation_20170908/scene_validation_images_20170908')

# Write the images to the TFRecord file.

# Map image id (file name) -> integer label id.
label = {item['image_id']: int(item['label_id']) for item in label_raw}

starttime = datetime.datetime.now()
# Long-running section below.

num = len(path)

try:
    for i, (img_path, img_id) in enumerate(zip(path, image_id)):
        # KeyError here means the image has no entry in the annotation file.
        write_tfrecord(label[img_id], img_path, tfrecord_file, img_id)
        if i % 1000 == 0:
            print(i)  # progress indicator
finally:
    # Close the writer even if an image fails, so the records already written
    # are flushed and the file handle is released.
    writer.close()

endtime = datetime.datetime.now()
print(endtime - starttime)

对于测试集(没有标签),需要先生成一个伪标注文件:每张图片的 image_id 为其文件名,label_id 统一填占位值 1,这样就能沿用上面的代码生成 tfrecord。可使用如下代码:

import json
import os


# Load the validation annotation file (label file), decoding it explicitly as
# UTF-8 rather than with the platform's locale encoding.
# NOTE(review): label_raw_val is not referenced anywhere below; kept so the
# script's behaviour (including failing if the file is missing) is unchanged.
with open('../ai_challenger_scene_validation_20170908/scene_validation_annotations_20170908.json', 'r', encoding='utf-8') as f:
    label_raw_val = json.load(f)


# Placeholder annotations for the test set; file_name2 below appends one
# {'image_id': ..., 'label_id': 1} dict per image it finds.
label_raw = []

def file_name2(file_dir, ext='.jpg'):
    """Collect image files under ``file_dir`` and record placeholder labels.

    Side effect: for every matching file, appends
    ``{'image_id': <file name>, 'label_id': 1}`` to the module-level list
    ``label_raw``. The test set has no labels, so ``1`` is only a placeholder
    (presumably so the test set can go through the same label-lookup code as
    the validation set).

    Args:
        file_dir: root directory to walk (sub-directories are included).
        ext: extension to keep, including the leading dot; matched
            case-insensitively, so ``'.jpg'`` also picks up ``'.JPG'``.

    Returns:
        (paths, names): parallel lists of full paths and bare file names,
        in sorted (reproducible) traversal order.
    """
    ext = ext.lower()
    paths = []
    names = []
    for root, dirs, files in os.walk(file_dir):
        # Sorting ``dirs`` in place controls the order os.walk descends into.
        dirs.sort()
        for name in sorted(files):
            if os.path.splitext(name)[1].lower() != ext:
                continue
            paths.append(os.path.join(root, name))
            names.append(name)
            label_raw.append({'image_id': name, 'label_id': 1})
    return paths, names

# Scan the test-set image folder (absolute path on the author's machine).
# file_name2 also appends one placeholder record per image to label_raw.
path, image_id = file_name2(
    '/home/wayne/python/kaggle/Ai_challenger/classification/ai_challenger_scene_test_a_20170922/scene_test_a_images_20170922'
)


# Save the placeholder annotations using the same {'image_id', 'label_id'}
# keys as the validation annotation file.
with open('scene_test_annotations.json', 'w') as f:
    json.dump(obj=label_raw, fp=f)

你可能感兴趣的:(TensorFlow)