生成和解析TFRecords文件

最近接触到制作数据集的知识,了解到TFRecords文件。记录笔记如下:

目录:

1.TFrecords文件介绍

2.将数据写入到TFRecords文件

3. 从TFRecords文件中读取数据***


一、TFRecords文件介绍

  1. 文件名:TFRecords文件是以 ** .tfrecords ** 为后缀
  2. 功能: 转化为TFRecords是为了将图片、标记等数据保存在本地(持久化),理解:就像文本文件一样 。
  3. 优点:
    a.可以自定义数据集的存储方式,方便在训练/测试神经网络时数据的读取。
    b.可以对大量的数据进行持久化。
    c.暂不清楚(后续补充)
  4. 缺点: 暂不清楚(后续补充)

二、写入到TFRecords文件

def productTFRs(save_tfrcordes_path):
    '''
    函数功能:将images和labels以TFReordes的形式存储
    参数:是TFRecordes文件的存储路径
    注意:很多地方必须要中括号"[]"
    '''
    
    print("正在将图片保存为TfRecords格式...")
    
    # 1.新建一个writer
    writer = tf.python_io.TFRecordWriter(path=save_tfrcordes_path)
    # getDatas()函数返回的是本地images和labels,且都是以二进制的形式存在
    images, labels = getDatas()
    process = 0 #打印进度
    
    # 2.for循环遍历每张图片和标签
    for image, label in zip(images, labels):
        process += 1
        
        # 3.把每张图片和标签封装到example中。如有不明,参考补充知识/code01.py
        feature_dict = {'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
                        'label_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label]))}
        features = tf.train.Features(feature=feature_dict)
        example = tf.train.Example(features=features)
        # print("____example____:\n",example,) # 观察example的形式
        
        # 4.把example进行序列化。序列化:从内存变成可存储或可传输的过程
        writer.write(example.SerializeToString())
        
        print("\r已经保存了{}张,已完成进度{}%!".format(process,process/len(images)*100),end='',flush=True)

    print("\n保存完成!!!")
    
    # 5.关闭writer
    writer.close()

三、读取TFRecords文件

def parseTFRs(save_tfrcordes_path,batch_size):
    '''
    函数功能:读取(解析)TFRecords文件
    参数:
    	1.TFRecords的存储路径
    	2.batch_size 是训练网络神经的轮数
    注意:很多地方必须要中括号"[]"
    返回值:image,label
    返回值说明: 1.image和label都是numpy形式
              2.image.shape = [batch_size,IMAGE_HEIGHT,IMAGE_WIDTH,IMAGE_CHANNELS]
              3.label.shape = [batch_size,ALL_KINDS]
    '''
    
    # 1.创建文件名队列
    filename_queue = tf.train.string_input_producer([save_tfrcordes_path])

    # 2.新建reader对象
    reader = tf.TFRecordReader()

    # 3.将读取的文件赋值给serialized_example
    # serialized_example是无法直接查看的,需要去按照特征进行解析。
    _,serialized_example = reader.read(filename_queue)

    # 4.根据已有的特征(feature),解析serialized_example
    features = tf.parse_single_example(serialized_example,features={
        'image_raw':tf.FixedLenFeature([],tf.string),
        'label_raw':tf.FixedLenFeature([],tf.string)
    })

    # 5.提取特征,二进制就需要解码,否则其他两种形式(tf.train.Int64List/tf.train.FloatList,具体参见code1)直接以访问字典的方法访问
    image = tf.decode_raw(features['image_raw'],tf.uint8)   # 需要解码,因为是二进制的形式
    image = tf.reshape(image,[IMAGE_HEIGHT,IMAGE_WIDTH,IMAGE_CHANNELS])
    label = tf.decode_raw(features['label_raw'],tf.int64) # 调试发现:上面解析serialized_example中,'label_raw':tf.FixedLenFeature([],tf.int64) 就不行,为啥?
    label = tf.reshape(label,[ALL_KINDS])

    # 6.将样本包装成一个一个的batch
    # tf.train.shuffle_batch() 将队列中数据打乱后再读取出来.
   	'''
   	部分重要参数如下:
	   	batch_size:一次获取数据的数量
	    capacity:队列中元素的最大数量.
	    min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别.
	并且各参数还有大小关系:capacity>min_after_dequeue, capacity>batch_size
    '''
    
    capacity = 2*batch_size
    min_after_dequeue = batch_size
    image,label = tf.train.shuffle_batch([image,label],batch_size=batch_size,capacity=capacity,min_after_dequeue=min_after_dequeue)
    # print(image.shape)      # (3, 40, 32)
    # print(label.shape)      # (3, 34)

    # 查看读取的效果(查看读取效果时一定要把线程协调器开启,否则(短时间?)看不见效果)
    # with tf.Session() as sess:
    #     tf.global_variables_initializer().run()
    #     coord = tf.train.Coordinator() #实例化协调器
    #     threads = tf.train.start_queue_runners(sess=sess,coord=coord) # 开起线程
    #     for i in range(10):
    #         print(sess.run([image,label]),end='--'*50)
    #     coord.request_stop()
    #     coord.join(threads=threads)

    return image,label

你可能感兴趣的:(tensorflow笔记)