AI Challenger 场景分类（1）：生成 TFRecord 文件

用时：30 min
原图大小：3.5 GB
TFRecord 文件大小：65.3 GB（约为原图的 19 倍！原因是原图为 JPEG 压缩格式，而本脚本写入的是解码后未经压缩的原始像素字节）

# -*- coding: utf-8 -*-
"""
Created on Thu Sep  7 19:25:38 2017

@author: wayne

http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html

"""


import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime



# Target directory: the TFRecord file is created alongside the training data.
record_PATH = 'ai_challenger_scene_train_20170904/'
tfrecord_file = '%strain.tfrecord' % record_PATH
# One module-level writer, used by write_to_tfrecord() for every record.
writer = tf.python_io.TFRecordWriter(tfrecord_file)


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding a one-element Int64List."""
    int64_values = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_values)

def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a one-element BytesList."""
    byte_values = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_values)

def get_image_binary(filename):
    """Decode an image file into its array shape and raw pixel bytes.

    Pillow + NumPy are used rather than TensorFlow's decoders, so no graph
    has to be built just to read a file.

    Args:
        filename: path of the image file to read.

    Returns:
        A ``(shape, raw)`` tuple. ``shape`` is an ``np.int32`` array of the
        decoded array's dimensions (``(h, w, c)`` for multi-band images;
        single-band images such as mode 'L' decode to a 2-D array, so
        ``(h, w)``). ``raw`` is the uint8 pixel data as bytes.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixels are loaded, instead of leaving one handle per image until GC.
    with Image.open(filename) as img:
        pixels = np.asarray(img, np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()  # uncompressed raw bytes of the pixel array

def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one labelled image as a ``tf.train.Example`` and write it.

    The Example carries the int64 features ``label``, ``h``, ``w``, ``c`` and
    the bytes feature ``image``.

    Args:
        label: integer class id of the image.
        shape: image dimensions, ``(h, w, c)`` or ``(h, w)``. A 2-D shape
            (single-band image) is stored with ``c = 1``, so ``h * w * c``
            still equals the byte length of ``binary_image``.
        binary_image: raw pixel bytes of the image.
        tfrecord_file: kept for interface compatibility only. It is NOT
            used; the record always goes to the module-level ``writer``.
    """
    # Single-band (e.g. grayscale) images decode to a 2-D array; indexing
    # shape[2] would raise IndexError and abort the whole conversion run.
    channels = shape[2] if len(shape) > 2 else 1
    # Convert NumPy scalars to plain Python ints before handing them to
    # protobuf, whose int64 type check may not accept NumPy integer types.
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _int64_feature(int(label)),
        'h': _int64_feature(int(shape[0])),
        'w': _int64_feature(int(shape[1])),
        'c': _int64_feature(int(channels)),
        'image': _bytes_feature(binary_image),
    }))
    writer.write(example.SerializeToString())


def write_tfrecord(label, image_file, tfrecord_file):
    """Decode ``image_file`` and hand it, tagged with ``label``, to write_to_tfrecord()."""
    write_to_tfrecord(label, *get_image_binary(image_file), tfrecord_file=tfrecord_file)

# Label file: a JSON array of records, each carrying 'image_id' and 'label_id'.
with open('ai_challenger_scene_train_20170904/scene_train_annotations_20170904.json',
          'r') as label_fp:
    label_raw = json.load(label_fp)

def file_name2(file_dir, ext='.jpg'):
    """Recursively collect files under ``file_dir`` that have a given extension.

    Args:
        file_dir: root directory to walk. A missing directory yields no files.
        ext: extension to match exactly (case-sensitive), including the dot.
            Defaults to ``'.jpg'``.

    Returns:
        ``(paths, names)``: the full paths of the matching files and their
        bare file names, in the same order. Directories and files are visited
        in sorted order, so the result (and hence the TFRecord record order)
        is reproducible across runs and machines.
    """
    paths = []
    names = []
    for root, dirs, files in os.walk(file_dir):
        # Sorting dirs in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for fname in sorted(files):
            if os.path.splitext(fname)[1] == ext:
                paths.append(os.path.join(root, fname))
                names.append(fname)
    return paths, names

# Image directory: collect every .jpg path together with its bare file name.
path, image = file_name2('ai_challenger_scene_train_20170904/scene_train_images_20170904')

# Write to TFRecord: first build the image_id -> integer label_id lookup table.
label = {item['image_id']: int(item['label_id']) for item in label_raw}


starttime = datetime.datetime.now()
# Long-running: decodes and writes every image collected above.

num = len(path)
missing = []
try:
    for i, (image_path, image_name) in enumerate(zip(path, image)):
        if image_name in label:
            write_tfrecord(label[image_name], image_path, tfrecord_file)
        else:
            # An image without an annotation used to abort the whole run with
            # KeyError; skip it and report it at the end instead.
            missing.append(image_name)
        if i % 1000 == 0:
            print(i)
finally:
    # Flush and close the output even if an image fails to decode.
    writer.close()

if missing:
    print('skipped %d of %d images without a label, e.g. %s'
          % (len(missing), num, missing[:5]))

endtime = datetime.datetime.now()
# `print (x).seconds` only works as a Python 2 print statement; under Python 3
# it prints the timedelta and then raises AttributeError on None.
# total_seconds() also counts whole days, which .seconds silently drops.
print(int((endtime - starttime).total_seconds()))


你可能感兴趣的：TensorFlow