CIFAR10 二进制数据读取

 

1 分析

  • 构造文件队列
  • 读取二进制数据并进行解码
  • 处理图片数据形状以及数据类型,批处理返回
  • 开启会话线程运行

2 代码

  • 定义CIFAR类,设置图片相关的属性
class CifarRead(object):
    """
    二进制文件的读取,tfrecords存储读取
    """

    def __init__(self):
        # 定义一些图片的属性
        self.height = 32
        self.width = 32
        self.channel = 3

        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes
  • 实现读取数据方法bytes_read(self, file_list)

    • 构造文件队列
    # 1、构造文件队列
    file_queue = tf.train.string_input_producer(file_list)
    
    • tf.FixedLengthRecordReader(bytes)读取
    # 2、使用tf.FixedLengthRecordReader(bytes)读取
    # 默认必须指定读取一个样本
    reader = tf.FixedLengthRecordReader(self.all_bytes)
    
    _, value = reader.read(file_queue)
    
    • 进行解码操作
    # 3、解码操作
    # (?, )   (3073, ) = label(1, ) + feature(3072, )
    label_image = tf.decode_raw(value, tf.uint8)
    # 为了训练方便,一般会把特征值和目标值分开处理
    print(label_image)
    
    • 将数据的标签和图片进行分割
    # 使用tf.slice进行切片
    label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
    
    image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
    
    print(label, image)
    
    • 处理数据的形状,并且进行批处理
    # 处理类型和图片数据的形状
    # 图片形状
    # reshape (3072, )----[channel, height, width]
    # transpose [channel, height, width] --->[height, width, channel]
    depth_major = tf.reshape(image, [self.channel, self.height, self.width])
    print(depth_major)
    
    image_reshape = tf.transpose(depth_major, [1, 2, 0])
    
    print(image_reshape)
    
    # 4、批处理
    image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
    

你可能感兴趣的:(CIFAR10 二进制数据读取)