TensorFlow学习8:制作数据集

将所有图片及其标签写入一个二进制数据集文件(TFRecord 格式)的过程

示例代码

def write_tfRecord(tfRecordName, image_path, label_path):
    """Serialize the images and labels listed in a label file into one TFRecord file.

    Packing images and labels into a single binary file, and reading the data
    back from it, is meant to improve memory utilisation. Each example is
    stored as a ``tf.train.Example`` whose features are key/value pairs:

    * ``'img_raw'`` -- the image's raw pixel bytes (``Image.tobytes()``);
    * ``'label'``   -- a length-10 one-hot ``int64`` list.

    Args:
        tfRecordName: path of the TFRecord file to create.
        image_path: prefix concatenated directly with each image filename, so
            it should end with a path separator.
        label_path: text file with one ``<image filename> <class index>`` pair
            per line; blank lines are skipped.
    """
    # Read the whole label list up front; ``with`` closes it even on error.
    with open(label_path, 'r') as f:
        contents = f.readlines()

    writer = tf.python_io.TFRecordWriter(tfRecordName)
    num_pic = 0
    try:
        for content in contents:
            value = content.split()
            if not value:
                # Tolerate blank lines (e.g. a trailing newline) instead of
                # failing with IndexError on value[0].
                continue
            img_path = image_path + value[0]
            # The context manager releases the image's file handle promptly.
            with Image.open(img_path) as img:
                img_raw = img.tobytes()
            # One-hot encode the integer class index.
            labels = [0] * 10
            labels[int(value[1])] = 1
            example = tf.train.Example(features=tf.train.Features(feature={
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
            }))
            # Serialize the Example to a byte string and append it to the file.
            writer.write(example.SerializeToString())
            num_pic += 1
            print("the number of picture:", num_pic)
    finally:
        # Flush and close the writer even if an image fails to load.
        writer.close()


def generate_tfRecord():
    """Make sure the output directory exists, then build the train and test TFRecords.

    Uses the module-level settings ``data_path``, ``tfRecord_train`` /
    ``tfRecord_test``, ``image_train_path`` / ``image_test_path`` and
    ``label_train_path`` / ``label_test_path``.
    """
    if os.path.exists(data_path):
        print("Already Exists")
    else:
        os.makedirs(data_path)
        print("Created")

    # Train split first, then test split.
    splits = (
        (tfRecord_train, image_train_path, label_train_path),
        (tfRecord_test, image_test_path, label_test_path),
    )
    for record_file, image_dir, label_file in splits:
        write_tfRecord(record_file, image_dir, label_file)


#解析文件
def read_tfRecord(tfRecord_path):
    """Build graph ops that read and decode one example from a TFRecord file.

    Args:
        tfRecord_path: path of a TFRecord file with ``'img_raw'`` (bytes) and
            ``'label'`` (10 x int64) features, as written by ``write_tfRecord``.

    Returns:
        ``(img, label)``: ``img`` is a float32 tensor of shape ``[784]`` with
        pixel values scaled into [0, 1]; ``label`` is a float32 tensor of
        shape ``[10]`` holding the stored one-hot vector.
    """
    # Queue of input filenames that the reader pulls records from.
    filename_queue = tf.train.string_input_producer([tfRecord_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Keys and shapes must match what the writer stored.
    features = tf.parse_single_example(serialized_example, features={
        'label': tf.FixedLenFeature([10], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    # 784 uint8 values per image -- presumably 28x28 single-channel images;
    # confirm against the source data.
    img.set_shape([784])
    img = tf.cast(img, tf.float32) * (1. / 255)
    label = tf.cast(features['label'], tf.float32)

    return img, label

def get_tfrecord(num, isTrain=True):
    """Return shuffled batches of images and labels from the train or test TFRecord.

    Args:
        num: batch size.
        isTrain: read the module-level ``tfRecord_train`` file when True,
            otherwise ``tfRecord_test``.

    Returns:
        ``(img_batch, label_batch)`` tensors from ``tf.train.shuffle_batch``.
        They are queue-based, so queue runners must be started (e.g. with
        ``tf.train.start_queue_runners``) before the tensors are evaluated.
    """
    # Pick the file from the module-level settings. Assigning to a local that
    # shares a global's name and then reading it would raise UnboundLocalError.
    tfRecord_path = tfRecord_train if isTrain else tfRecord_test
    img, label = read_tfRecord(tfRecord_path)

    # min_after_dequeue keeps a buffer of examples for shuffling; capacity
    # bounds the queue size.
    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label],
        batch_size=num,
        num_threads=2,
        capacity=1000,
        min_after_dequeue=700,
    )

    return img_batch, label_batch

def main():
    """Script entry point: create the train and test TFRecord files."""
    generate_tfRecord()


if __name__ == '__main__':
    main()




参考:人工智能实践:Tensorflow笔记

你可能感兴趣的:(TensorFlow学习8:制作数据集)