https://www.cs.toronto.edu/~kriz/cifar.html
二进制版本包含文件data_batch_1.bin,data_batch_2.bin,...,data_batch_5.bin以及test_batch.bin
。这些文件中的每一个格式如下,数据中每个样本包含了特征值和目标值:
<1×标签> <3072×像素>
...
<1×标签> <3072×像素>
第一个字节是第一个图像的标签,它是一个0-9范围内的数字。接下来的3072个字节是图像像素的值。前1024个字节是红色通道值,下1024个绿色,最后1024个蓝色。值以行优先顺序存储,因此前32个字节是图像第一行的红色通道值。 每个文件都包含10000个这样的3073字节的“行”图像,但没有任何分隔行的限制。因此每个文件应该完全是30730000字节长。
class CifarRead(object):
"""
二进制文件的读取,tfrecords存储读取
"""
def __init__(self):
# 定义一些图片的属性
self.height = 32
self.width = 32
self.channel = 3
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.label_bytes + self.image_bytes
实现读取数据方法bytes_read(self, file_list)
# 1、构造文件队列
file_queue = tf.train.string_input_producer(file_list)
# 2、使用tf.FixedLengthRecordReader(bytes)读取
# 默认必须指定读取一个样本
reader = tf.FixedLengthRecordReader(self.all_bytes)
_, value = reader.read(file_queue)
# 3、解码操作
# (?, ) (3073, ) = label(1, ) + feature(3072, )
label_image = tf.decode_raw(value, tf.uint8)
# 为了训练方便,一般会把特征值和目标值分开处理
print(label_image)
# 使用tf.slice进行切片
label = tf.cast(tf.slice(label_image, [0], [self.label_bytes]), tf.int32)
image = tf.slice(label_image, [self.label_bytes], [self.image_bytes])
print(label, image)
# 处理类型和图片数据的形状
# 图片形状
# reshape (3072, )----[channel, height, width]
# transpose [channel, height, width] --->[height, width, channel]
depth_major = tf.reshape(image, [self.channel, self.height, self.width])
print(depth_major)
image_reshape = tf.transpose(depth_major, [1, 2, 0])
print(image_reshape)
# 4、批处理
image_batch, label_batch = tf.train.batch([image_reshape, label], batch_size=10, num_threads=1, capacity=10)
这里的图片形状设置从1维的排列到3维数据的时候,涉及到NHWC与NCHW的概念:
1)NHWC与NCHW
在读取设置图片形状的时候有两种格式:
设置为 "NHWC" 时,排列顺序为 [batch, height, width, channels];
设置为 "NCHW" 时,排列顺序为 [batch, channels, height, width]。
其中 N 表示这批图像有几张,H 表示图像在竖直方向有多少像素,W 表示水平方向像素数,C 表示通道数。
Tensorflow默认的[height, width, channel]
假设RGB三通道两种格式的区别如下图所示:
1 理解
假设1, 2, 3, 4-红色 5, 6, 7, 8-绿色 9, 10, 11, 12-蓝色
# 1、想要变成:[2 height, 2width, 3channel],但是输出结果不对
In [7]: tf.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [2, 2, 3]).eval()
Out[7]:
array([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]], dtype=int32)
# 2、所以要这样去做
In [8]: tf.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [3, 2, 2]).eval()
Out[8]:
array([[[ 1, 2],
[ 3, 4]],
[[ 5, 6],
[ 7, 8]],
[[ 9, 10],
[11, 12]]], dtype=int32)
# 接着使用tf.transpose ,0,1,2代表三个维度标记
# Convert from [depth, height, width] to [height, width, depth].
# 0,1,2-----> 1, 2, 0
In [17]: tf.transpose(depth_major, [1, 2, 0]).eval()
Out[17]:
array([[[ 1, 5, 9],
[ 2, 6, 10]],
[[ 3, 7, 11],
[ 4, 8, 12]]], dtype=int32)
2 转换API
a
. Permutes the dimensions according to perm
.
2)处理图片的形状
所以在读取数据处理形状的时候
import tensorflow as tf
import os
class Cifar(object):
# 初始化
def __init__(self):
# 图像的大小
self.height = 32
self.width = 32
self.channels = 3
# 图像的字节数
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channels
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self, file_list):
# 读取二进制文件
# print("read_and_decode:\n", file_list)
# 1、构造文件名队列
file_queue = tf.train.string_input_producer(file_list)
# 2、构造二进制文件阅读器
reader = tf.FixedLengthRecordReader(self.bytes)
key, value = reader.read(file_queue)
print("key:\n", key)
print("value:\n", value)
# 3、解码
decoded = tf.decode_raw(value, tf.uint8)
print("decoded:\n", decoded)
# 4、基本的数据处理
# 切片处理,把标签值和特征值分开
label = tf.slice(decoded, [0], [self.label_bytes])
image = tf.slice(decoded, [self.label_bytes], [self.image_bytes])
print("label:\n", label)
print("image:\n", image)
# 改变图像的形状
image_reshaped = tf.reshape(image, [self.channels, self.height, self.width])
# 转置
image_transposed = tf.transpose(image_reshaped, [1, 2, 0])
print("image_transposed:\n", image_transposed)
# 类型转换
label_cast = tf.cast(label, tf.float32)
image_cast = tf.cast(image_transposed, tf.float32)
# 5、批处理
label_batch, image_batch = tf.train.batch([label_cast, image_cast], batch_size=10, num_threads=1, capacity=10)
return label_batch, image_batch
if __name__ == "__main__":
# 构造文件名列表
file_name = os.listdir("./cifar-10-batches-bin")
print("file_name:\n", file_name)
file_list = [os.path.join("./cifar-10-batches-bin/", file) for file in file_name if file[-3:] == "bin"]
print("file_list:\n", file_list)
# 调用读取二进制文件的方法
cf = Cifar()
label, image = cf.read_and_decode(file_list)
# 开启会话
with tf.Session() as sess:
# 创建线程协调器
coord = tf.train.Coordinator()
# 创建线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 打印结果
print("label:\n", sess.run(label))
print("image:\n", sess.run(image))
# 回收资源
coord.request_stop()
coord.join(threads)